{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deep Reinforcement Learning in Action\n",
    "### by Alex Zai and Brandon Brown\n",
    "\n",
    "#### Chapter 3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 3.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from Gridworld import Gridworld\n",
    "game = Gridworld(size=4, mode='static')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "game.display()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([['+', '-', ' ', ' '],\n",
       "       [' ', 'W', ' ', ' '],\n",
       "       [' ', ' ', ' ', ' '],\n",
       "       [' ', ' ', ' ', 'P']], dtype='<U2')"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "game.makeMove('d')\n",
    "game.makeMove('d')\n",
    "game.makeMove('d')\n",
    "game.display()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-1"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "game.reward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[[0, 0, 0, 0],\n",
       "        [0, 0, 0, 0],\n",
       "        [0, 0, 0, 0],\n",
       "        [0, 0, 0, 1]],\n",
       "\n",
       "       [[1, 0, 0, 0],\n",
       "        [0, 0, 0, 0],\n",
       "        [0, 0, 0, 0],\n",
       "        [0, 0, 0, 0]],\n",
       "\n",
       "       [[0, 1, 0, 0],\n",
       "        [0, 0, 0, 0],\n",
       "        [0, 0, 0, 0],\n",
       "        [0, 0, 0, 0]],\n",
       "\n",
       "       [[0, 0, 0, 0],\n",
       "        [0, 1, 0, 0],\n",
       "        [0, 0, 0, 0],\n",
       "        [0, 0, 0, 0]]], dtype=uint8)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "game.board.render_np()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(4, 4, 4)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "game.board.render_np().shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 3.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    " \n",
    "import numpy as np\n",
    "import torch\n",
    "from matplotlib import pylab as plt\n",
    " \n",
    "from Gridworld import Gridworld\n",
    " \n",
    "# Layer widths: 64 inputs (the 4x4x4 board from render_np(), flattened),\n",
    "# two hidden layers, and 4 outputs (one value per entry of action_set).\n",
    "l1, l2, l3, l4 = 64, 150, 100, 4\n",
    " \n",
    "model = torch.nn.Sequential(\n",
    "    torch.nn.Linear(l1, l2), torch.nn.ReLU(),\n",
    "    torch.nn.Linear(l2, l3), torch.nn.ReLU(),\n",
    "    torch.nn.Linear(l3, l4),\n",
    ")\n",
    "loss_fn = torch.nn.MSELoss()\n",
    "learning_rate = 1e-3\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
    " \n",
    "gamma = 0.9    # discount factor for future rewards\n",
    "epsilon = 1.0  # starting exploration rate for epsilon-greedy action selection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Network output index -> Gridworld move: 0 up, 1 down, 2 left, 3 right\n",
    "action_set = dict(enumerate(['u', 'd', 'l', 'r']))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 3.3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Online Q-learning on the static board: one gradient step per move, no replay buffer.\n",
    "epochs = 1000\n",
    "losses = []\n",
    "for i in range(epochs):\n",
    "    game = Gridworld(size=4, mode='static')\n",
    "    # Flatten the 4x4x4 board to a (1, 64) vector and add small uniform noise in [0, 0.1)\n",
    "    state_ = game.board.render_np().reshape(1,64) + np.random.rand(1,64)/10.0\n",
    "    state1 = torch.from_numpy(state_).float()\n",
    "    status = 1\n",
    "    while(status == 1):\n",
    "        qval = model(state1)\n",
    "        qval_ = qval.detach().numpy()\n",
    "        if (random.random() < epsilon):  # explore: uniformly random action\n",
    "            action_ = np.random.randint(0,4)\n",
    "        else:  # exploit: greedy action under the current Q estimates\n",
    "            action_ = np.argmax(qval_)\n",
    "        \n",
    "        action = action_set[action_]\n",
    "        game.makeMove(action)\n",
    "        state2_ = game.board.render_np().reshape(1,64) + np.random.rand(1,64)/10.0\n",
    "        state2 = torch.from_numpy(state2_).float()\n",
    "        # -1 while the game is still in play; any other value ends the game\n",
    "        # (a win yields +10 per the test_model output; a loss is presumably -10)\n",
    "        reward = game.reward()\n",
    "        with torch.no_grad():\n",
    "            newQ = model(state2)\n",
    "        maxQ = torch.max(newQ)\n",
    "        if reward == -1:  # game still in play: bootstrap from the next state\n",
    "            Y = reward + (gamma * maxQ)\n",
    "        else:  # terminal transition: the target is the reward alone\n",
    "            Y = reward\n",
    "        Y = torch.tensor(float(Y))  # 0-dim float target, outside the autograd graph\n",
    "        X = qval.squeeze()[action_]\n",
    "        loss = loss_fn(X, Y)\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        losses.append(loss.item())\n",
    "        optimizer.step()\n",
    "        state1 = state2\n",
    "        if reward != -1:  # game over (won or lost)\n",
    "            status = 0\n",
    "    if epsilon > 0.1:  # linearly anneal exploration, stopping near 0.1\n",
    "        epsilon -= (1/epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x11ac4e990>]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAdkElEQVR4nO3deZhU9Z3v8feXVSUqIC1ykdiImIhJRNMajBnH0UQUM4O5N86DmUeJ44RJoom5M5M7aEbjLGY0Rh1NjA5GFIz7FsiICyCKMWFpkH1tFqGlaVrWZu/le/+o01jdVHVVV53qqjr1eT1PP336V6dOfc/p7s/51e8sZe6OiIhES5d8FyAiIuFTuIuIRJDCXUQkghTuIiIRpHAXEYmgbvkuAKBfv35eXl6e7zJERIrKggULPnb3skSPFUS4l5eXU1lZme8yRESKipl9mOwxDcuIiESQwl1EJIIU7iIiEZQy3M3sGDObZ2aLzWy5mf1r0D7YzOaa2Voze97MegTtPYOfq4LHy3O7CiIi0lY6PfdDwKXufg4wHLjCzEYA9wAPuPtQYCdwYzD/jcBOdz8DeCCYT0REOlHKcPeYvcGP3YMvBy4FXgraJwFXB9Ojg58JHr/MzCy0ikVEJKW0xtzNrKuZLQK2AdOBdcAud28MZqkGBgbTA4HNAMHju4GTEixznJlVmlllXV1ddmshIiKtpBXu7t7k7sOBU4ELgLMSzRZ8T9RLP+q+wu4+wd0r3L2irCzhOfih2LR9P7PXaOchIqWlQ2fLuPsu4B1gBNDbzFougjoV2BJMVwODAILHTwR2hFFsJi6+dxbXT5yXr5cXEcmLdM6WKTOz3sH0scBXgZXALOCbwWxjgSnB9NTgZ4LH33Z9IoiISKdK5/YDA4BJZtaV2M7gBXf/HzNbATxnZv8BfAA8Hsz/OPCUmVUR67GPyUHdIiLSjpTh7u5LgHMTtK8nNv7etv0gcE0o1YmISEZ0haqISAQp3EVEIkjhLiISQQp3EZEIUriLiESQwl1EJIIU7iIiEaRwFxGJIIW7iEgEKdxFRCKo5ML94VlVlI9/Dd3LTESiLHLhfrChib2HGpM+ft9bqwFoVraLSIRFLtyveug9PvfTN/NdhohIXkUu3NfV7ct3CSIieRe5cBcRkRIP91+/Ezu4eqixKd+liIiEqqTDfcLs9QDsP6RwF5FoKelwFxGJKoW7iEgEKdxFRCJI4S4iEkElG+66/YCIRFnJhbuZ5bsEEZGcK7lwFxEpBQp3INUAzeHGZh6eVcXhxuZOqUdEJFspw93MBpnZLDNbaWbLzeyWoP1OM/vIzBYFX6PinnOrmVWZ2WozG5nLFchGugM0T7y/gXvfXM0T72/IaT0iImHplsY8jcA/uvtCMzseWGBm04PHHnD3X8TPbGbDgDHA2cD/AmaY2ZnuXrSXge47HCv9QEPRroKIlJiUPXd3r3H3hcF0PbASGNjOU0YDz7n7IXffAFQBF4RRrIiIpKdDY+5mVg6cC8wNmm42syVmNtHM+gRtA4HNcU+rJsHOwMzGmVmlmVXW1dV1uHAREUku7XA3s08BLwM/cvc9wCPAEGA4UAPc1zJrgqcfdczS3Se4e4W7V5SVlXW4cBERSS6tcDez7sSC/Wl3fwXA3Wvdvcndm4HH+GTopRoYFPf0U4Et4ZUsIiKppHO2jAGPAyvd/f649gFxs30DWBZMTwXGmFlPMxsMDAXmhVdybu091Mhv53yoK1hFpKilc7bMRcB1wFIzWxS03QZca2bDiQ25bAT+HsDdl5vZC8AKYmfa3FSIZ8rER3d8kP/r1OW8uKCawf16cdEZ/Vo/R3kvIkUiZbi7+x9IPI4+rZ3n3AXclUVdORO/IoluRbBj32EADhz+ZH+kGxaISLHRFaoiIhGkcBcRiSCFu4hIBCncRUQiSOEuIhJBCvckdNajiBQzhXsb7X1QkwJfRIqF
wp3Uoa1P5hORYlPS4a7MFpGoKtlw160ERCTKSi7ckw2xvFi5mXV1ezu3GBGRHEnnxmEl4ccvLaFHty5cPLRf6plFRApcyfXc23O4sfnItG75KyLFrOTCPXVm6zCriBS/kgv3Fhmd3qjevIgUiZIN93ipMtvUmxeRIlPS4a6Lk0Qkqko63DXKIiJRVdLh3kI9eBGJmpINd/XaRSTKSi7c0+2lK/tFpJiVXLinolv+ikgUKNzToDF5ESk2CncRkQhKGe5mNsjMZpnZSjNbbma3BO19zWy6ma0NvvcJ2s3MHjKzKjNbYmbn5XolspXuwdWmZmfDx/tyW4yISAjS6bk3Av/o7mcBI4CbzGwYMB6Y6e5DgZnBzwBXAkODr3HAI6FXHZKODrf8+p11/MUv3mHT9v25KUhEJCQpw93da9x9YTBdD6wEBgKjgUnBbJOAq4Pp0cBkj5kD9DazAaFXnkd1ew8B0NjUnGJOEZH86NCYu5mVA+cCc4H+7l4DsR0AcHIw20Bgc9zTqoO2tssaZ2aVZlZZV1fX8cpzLNVQzeLNuzjjJ6/z7prCq11EJO1wN7NPAS8DP3L3Pe3NmqDtqKh09wnuXuHuFWVlZemWkXPpjtTM37gDgHdXK9xFpPCkFe5m1p1YsD/t7q8EzbUtwy3B921BezUwKO7ppwJbwilXRETSkc7ZMgY8Dqx09/vjHpoKjA2mxwJT4tqvD86aGQHsbhm+KSTegUuSdJq7iBSbdD5D9SLgOmCpmS0K2m4D7gZeMLMbgU3ANcFj04BRQBWwH7gh1IqzFLs3e7bXmupaVREpbCnD3d3/QPLO62UJ5nfgpizr6lQd6cWLiBSDEr9CNb0Bl/aiXzsGESlEJR7u7Wk/tE03nBGRAqZwb0OZLSJRoHAXEYkghXsG9ClOIlLoFO5tJApujdSISLFRuCelSBeR4qVwz5KGaESkEJVsuLcK5QwCWv16ESlkpRfucanc/mmP6pKLSPEqvXBPIZ3z3BX7IlLoFO4iIhGkcE+DrloVkWKjcBcRiSCFe4bUmxeRQqZwFxGJoNIL9wSnuiQ6+6W9i5N04ZKIFLrSC/c4iUZWTJcniUgElHS4i4hEVUmHexijK64xGhEpQKUX7glGXVINxCT6SD0N3ohIISu9cA+J+usiUsgU7lnSB2WLSCFSuCfRXs9c4+wiUuhShruZTTSzbWa2LK7tTjP7yMwWBV+j4h671cyqzGy1mY3MVeFhio9qdcRFJArS6bk/CVyRoP0Bdx8efE0DMLNhwBjg7OA5vzazrmEVGzbluIhEVcpwd/fZwI40lzcaeM7dD7n7BqAKuCCL+gpS/Di7hmhEpBBlM+Z+s5ktCYZt+gRtA4HNcfNUB21HMbNxZlZpZpV1dXVZlJEf6vWLSCHLNNwfAYYAw4Ea4L6gPVHmJezauvsEd69w94qysrIMyxARkUQyCnd3r3X3JndvBh7jk6GXamBQ3KynAluyKzE3shlN0VCMiBS6jMLdzAbE/fgNoOVMmqnAGDPraWaDgaHAvOxKDJeGU0SkFHRLNYOZPQtcAvQzs2rgp8AlZjac2JDLRuDvAdx9uZm9AKwAGoGb3L0pN6XnljrnIlLMUoa7u1+boPnxdua/C7grm6LyqaPnuWsfICKFSFeoklkvXbcdEJFCVtLhrnwWkagq6XBPV9udgIZiRKTQKdyzNPlPH1K9c3++yxARaUXhHoJlH+3JdwkiIq0o3JNwDb6ISBFTuLdhaV7mpIOxIlLISjbcw+2Zq5cvIoWl5MI9UY+7o0Gvq1dFpNCVXLjHS3cIRkSk2JR0uIuIRJXCPQQtwzRvLNvKG8u25rcYERHSuHFYqUo1rp5oQOe7v10AwMa7rwq/IBGRDlDPXUQkghTuSbR3HrvrEicRKXAK9xAo6EWk0CjcRUQiSOGOLkoSkegp2XB3T//+MKkudrr1laUhVCQiEp6SC/ewrkqNX8ruAw2hLFNEJCwlF+7pnufS7lCNhnFEpMCV
XLinlGHHvmpbfbh1iIhkoaTDPcwDqV+9f3Z4CxMRyVJJh3sLffCGiERNynA3s4lmts3MlsW19TWz6Wa2NvjeJ2g3M3vIzKrMbImZnZfL4gEenlXFD579INcvIyJSVNLpuT8JXNGmbTww092HAjODnwGuBIYGX+OAR8IpM7HGpmbufXM1v1+8JavlpLxJWKKevbr7IlLAUoa7u88GdrRpHg1MCqYnAVfHtU/2mDlAbzMbEFaxbTVlOWiufBaRqMp0zL2/u9cABN9PDtoHApvj5qsO2o5iZuPMrNLMKuvq6jIsI3d0tqOIFLOwD6gm6gsnzEl3n+DuFe5eUVZWFnIZmVNnXkSiINNwr20Zbgm+bwvaq4FBcfOdCmQ3IJ4j6pmLSJRlGu5TgbHB9FhgSlz79cFZMyOA3S3DN7mQya0E9KHYIlIKUn7Mnpk9C1wC9DOzauCnwN3AC2Z2I7AJuCaYfRowCqgC9gM35KDmgqBdhIgUspTh7u7XJnnosgTzOnBTtkUVOg3piEih0xWqaQirlz5h9jom/XFjSEsTEUkuZc+9kIV1nnrYPfHaPQfpf8IxR7X/bNoqAMZ+uTzkVxQRaa2ke+7t7Rs8iwukvvSzmSz7aHfGzxcRyVZJh3siFtLbgXV1e0NZjohIJoo63HXGiohIYkUd7vnSkc9fFRHJB4W7iEgElWy4Z3PAVESk0JVcuGcynJLoOdo3iEghK+pwDytf1YsXkagp6nDPVqLTHsM6Trp97+GQliQi0nElHe7ZSDW882//s6JzChERSUDhngHXrcNEpMAp3EVEIkjhLiISQQr3NOjTm0Sk2Cjck9DZkSJSzEo23JNld7oXOak3LyKFrOTCPVEkp+qltz07Rr16ESl0JRfuIiKlQOHehnrlIhIFCvckdL92ESlmCncRkQhSuKch0Zkx6tmLSCEruXBPd0i9vbH3MIblpy2tYVv9wRCWJCJytKzC3cw2mtlSM1tkZpVBW18zm25ma4PvfcIptXN0Ro9836FGvv/0Qq77zbzcv5iIlKQweu5/4e7D3b0i+Hk8MNPdhwIzg59zoljPbGlsjhW+ZdeBPFciIlGVi2GZ0cCkYHoScHUOXiOvQuvca9xeRHIk23B34C0zW2Bm44K2/u5eAxB8PznRE81snJlVmlllXV1dlmV0XLH2+kVE0tEty+df5O5bzOxkYLqZrUr3ie4+AZgAUFFR0WlRq86yiJSCrHru7r4l+L4NeBW4AKg1swEAwfdt2RZZaBztJESksGUc7mbWy8yOb5kGLgeWAVOBscFsY4Ep2RaZD/E3Cwv9DBoNCYlIjmUzLNMfeNViydcNeMbd3zCz+cALZnYjsAm4JvsyO09n9sjV+xeRXMk43N19PXBOgvbtwGXZFFUq9hxs5Kk/beS6C8vzXYqIREzJXaEalmxGVuKHfG6fsjz7YkRE2lC4i4hEUEmHe6YHSt1d4+UiUtBKOtxFRKJK4Z5ELq9g1dWxIpJrpRvuSQLW8nCj9odnVXX6a4pItJVcuOcjvFN55J11+S5BRCKm5MK9EGhURkRyTeGeAUcfsyciha2ow91D6gPn+wCn57sAEYmcog73bOWr960wF5FcK+lwLxSKehEJm8IdeHDm2qPa4jvXhXiGjYhIexTuwMsLq49Md0aMq6cuIrlWMuFevXN/eAsLOZ01BB+zfe8hHY8QCUnJhPvbq8L9tD/Loo/fXn69tqSGFVv2ZLzsYrVp+36++B8zeOy99fkuRSQSSibc2wrrNMowHGhoOjJ90zMLGfXQe3msJj82B++sZq2qy3Ml+dXU7Dw150MamprzXYoUuWw+Zq+guHtaBz4L4dBooh1L9c79NDQVzg6nsxXC76UQPD9/M7f/bhn1Bxv4/iVn5LscKWKRCfdCdOBwrEd+bI+urR9IkOFfuWdWJ1SUGzNW1FJ2fE/OGdQ762UV0juqfNhzsAGAXfsb8lyJFLuSCfc7pixnZU3rsez2xs3DiJjz/n06Bxqa
2Hj3VaEvu5D83eRKgKPWsyNa3nWV+vHU+L/I3y/ewsGGJq6pGJS3eqR4RWbMPZ1QeHbe5tQzBf9dtXsOtm365LXSjOf4sfR4/zVjTVrPT8Td+eO6j/N6Vom7887qbaHW0DKiVuLZ/sl2cOcHz37Aj19akt+CpGhFJtzDdu+bq5M+Nnf9DhqbM4+h3y+uSTnPWbe/waPvruO+t1a3CtEXF1Tzrcfm8rtFH2X02jNW1PL60tSvn0hjUzNTFn3ElEVb+PYT8/nt3E1Hhp5CU+Lp3vJushDfwTw/fxPl4187MnQk6Zu+opZN20M8HTsNCvc2mtII7Vc/+Ig7pizL+DW6dkl9+PBAQxN3v76KX75dxcqaei697x1Wb62nekfsD2TOuh24O4s27+LD7fsSLqOp2Y9an7+bXMn3nl6YUd2PvbeBW55bxH/Pjp2uuGn7Ps66441Wr5ep+C1SvXM/y7fsZtlHuzNeXrFa/3Hi32Uh+M17GwDYuvtgijnT19jUzD+9uJh1dXtDW2aheW1JDd+ZXMml973Tqa8bmXDvaKw0NDmbdhy9J52yaEvK526rP5RVz333gY71fJ6e+yHr6/Yx8r9m89DbsU9ter5yM0/N+ZCrH36fP7/3nSPz7jnYcOQ8+SG3TWPIbdM42NDEtvrW/5Dl41/jM//yOgcTDB01NzsPzljLzn2Hj7TdOXU597yxCoBd+2PtbbfB/I07OrReiayurecr98ziqof+wNd/+QcA9h5qZEdcLWFZWr2bGStqQ19uNp6dtwmApriu+wPTkw/jTV9Ry5LqXbg7s1Zty2oHG2/5lt1ccu+shH+rze6hnaq5bMseXlpQzf99flEoyytENz0T60xlkxmZyNkBVTO7AngQ6Ar8xt3vztVrZSKdK1Z/9fZabr50KGtq67N+vWzGp5+euylh+x1Tlh+ZLh//Gud+ujdVtXupP9TIZZ89+chjX7nnbT7ee3Q4Hmps5rO3v8G7P76E70yuZE3tXt4ffymravbwwIw1PDBjDUPKerGurnVvsibouT3x/sZW7fsPNwJw1UPvccNFg/nmF0896jWnLt7CD5/9gJe/dyHH9ejGKSccQ59ePY483jZMHpyxlgeCYxQb/nNU2vf5adneyeY/1NjEX/4qtvP42rD+PHZ9xZHHXltSw/nlfTj5hGMSPnfzjv24w6dPOg6IbXuAV77/Zco+1ZNBfY9Lq8Z4M1bUtuqgxO9YH5y5lmN7dOXLQ07iC6f2Zm1tPSce252TTziG7wQHs5+84XxueHI+//C1M/nhZUNbLbuhqZn7p6/huxcP4cTjuqes5fWlNUfe3c1Zv52RZ58CfHI84NZXlvLBpl1svPsqllbv5nMDT8j4/ktdg+els1Oqqz/Etx6bw8Rvn5/RNi41losDc2bWFVgDfA2oBuYD17r7ikTzV1RUeGVlZYdfZ8e+w5z379OTPn7B4L7M25Bdb/L/XfEZfv5G8vH3dF31hQG8tiSzse5i8udnlvHumtYXIv38m1+gqxn7Dje22iFl4sExw7nluUUc270rT9xwPmMmzDny2DmDenPisd2ZvSb5hVA/vGwo67bt5bUExx1O6tWD7XGh2rdXD3p268Ivrz2Xf/ndMlZtrefcT/fmg027APj8wBNZmmTo6Ftf+jTPJNkp3zbqs/xs2qq01retp268gOsen9fuPCNO78uc9Tt45jtfolePbry+bCuPvhv7KMfLh/VnxOknUX+wkS4GfzPiNLp2MT7YtJNn523izeWt38kM7teLDR/v4weXnsEv3078Wb8/vGwooz5/CtU7DjBn/XYuHHISp5x4DF3M6NmtC6u31vO9pxdy+9eHMfLs/vxs2kqmLd3KL645h97Hdj9yttXEb1dwsKGZOeu389Wz+tPkzj+/tIRjusd2bOs/3se8DTuO/J5+e+OX+OJpfZj4/gb+umIQ5981gxOO6cbin17O4urdTP7jRm676ixOOKY7uw800KtnV3p260rXLsaq
rXuYv3En1404DYjtABuamnGP3Wtq9DkD6dIFHpq5li8P6ceFQ04C4HBTM4cbm+liRheDlTX1jDi9Lytr6tm5/zC79jdQ3u84BvY+lh7durCkenerv9GRZ/dnwInHcvGZ/fj8wN7U7jnI5waemNHfAoCZLXD3ioSP5SjcLwTudPeRwc+3Arj7fyaaP9Nwv/v1VUf+aEVEitHIs/vz39clzOeU2gv3XI25DwTizzusDtriixpnZpVmVllXl9kl5+Ov/GzmFYqIFIDLh52Sk+Xmasw90QBcq7cI7j4BmACxnnumL5TNhTMiIlGVq557NRB/Wd2pQOrTUEREJBS5Cvf5wFAzG2xmPYAxwNQcvZaIiLSRk2EZd280s5uBN4mdCjnR3bM7TUJERNKWs/Pc3X0aMC1XyxcRkeQic4WqiIh8QuEuIhJBCncRkQhSuIuIRFBObj/Q4SLM6oAPM3x6P+DjEMvJpWKpVXWGr1hqVZ3hy2Wtp7l7WaIHCiLcs2FmlcnurVBoiqVW1Rm+YqlVdYYvX7VqWEZEJIIU7iIiERSFcJ+Q7wI6oFhqVZ3hK5ZaVWf48lJr0Y+5i4jI0aLQcxcRkTYU7iIiEVTU4W5mV5jZajOrMrPxeXj9QWY2y8xWmtlyM7slaL/TzD4ys0XB16i459wa1LvazEZ21rqY2UYzWxrUUxm09TWz6Wa2NvjeJ2g3M3soqGWJmZ0Xt5yxwfxrzWxsDur8TNx2W2Rme8zsR4WwTc1sopltM7NlcW2hbUMz+2LwO6oKnpvRp04nqfNeM1sV1PKqmfUO2svN7EDcdn00VT3J1jnEWkP7XVvstuNzg1qft9gtyMOq8/m4Gjea2aKgPa/b9Ah3L8ovYrcSXgecDvQAFgPDOrmGAcB5wfTxxD4UfBhwJ/BPCeYfFtTZExgc1N+1M9YF2Aj0a9P2c2B8MD0euCeYHgW8TuwTtUYAc4P2vsD64HufYLpPjn/HW4HTCmGbAhcD5wHLcrENgXnAhcFzXgeuDLHOy4FuwfQ9cXWWx8/XZjkJ60m2ziHWGtrvGngBGBNMPwp8L6w62zx+H3BHIWzTlq9i7rlfAFS5+3p3Pww8B4zuzALcvcbdFwbT9cBK2nxWbBujgefc/ZC7bwCqiK1HvtZlNDApmJ4EXB3XPtlj5gC9zWwAMBKY7u473H0nMB24Iof1XQasc/f2rl7utG3q7rOBHQleP+ttGDx2grv/yWP/4ZPjlpV1ne7+lrs3Bj/OIfbpaEmlqCfZOodSazs69LsOesWXAi9lW2t7dQav89fAs+0to7O2aYtiDveUH8LdmcysHDgXmBs03Ry8BZ4Y9xYrWc2dsS4OvGVmC8xsXNDW391rILajAk4ugDrjjaH1P0yhbVMIbxsODKZzXS/A3xLrNbYYbGYfmNm7ZvZnQVt79SRb5zCF8bs+CdgVt1PL1Tb9M6DW3dfGteV9mxZzuKf8EO7OYmafAl4GfuTue4BHgCHAcKCG2Fs2SF5zZ6zLRe5+HnAlcJOZXdzOvPmsM1ZAbGz0r4AXg6ZC3Kbt6WhdnVKvmf0EaASeDppqgE+7+7nAPwDPmNkJnVVPEmH9rjtrHa6ldSekILZpMYd7QXwIt5l1JxbsT7v7KwDuXuvuTe7eDDxG7G0jJK855+vi7luC79uAV4OaaoO3ii1vGbflu844VwIL3b02qLvgtmkgrG1YTeuhktDrDQ7efh34m2BYgGCIY3swvYDY2PWZKepJts6hCPF3/TGx4bBubdpDEyz7fwPPx9VfENu0mMM97x/CHYy1PQ6sdPf749oHxM32DaDlCPtUYIyZ9TSzwcBQYgdYcrouZtbLzI5vmSZ2cG1Z8BotZ2uMBabE1Xm9xYwAdgdvFd8ELjezPsFb5cuDtlxo1RsqtG0aJ5RtGDxWb2Yjgr+r6+OWlTUzuwL4Z+Cv3H1/XHuZmXUNpk8ntv3Wp6gn2TqHVWsov+tgBzYL+GauagW+Cqxy9yPDLQWzTbM9IpvP
L2JnJKwhtmf8SR5e/yvE3lYtARYFX6OAp4ClQftUYEDcc34S1LuauLMhcrkuxM4iWBx8LW9ZPrExyZnA2uB736DdgIeDWpYCFXHL+ltiB7KqgBtytF2PA7YDJ8a15X2bEtvZ1AANxHphN4a5DYEKYkG2DvgVwRXkIdVZRWxcuuXv9NFg3v8T/E0sBhYCf5mqnmTrHGKtof2ug7/9ecH6vwj0DKvOoP1J4Ltt5s3rNm350u0HREQiqJiHZUREJAmFu4hIBCncRUQiSOEuIhJBCncRkQhSuIuIRJDCXUQkgv4/bIsVuDDxPJAAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(losses)\n",
    "plt.xlabel(\"Training step (one per move)\")\n",
    "plt.ylabel(\"MSE loss\")\n",
    "plt.title(\"Listing 3.3 training loss\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Leaf tensors that autograd will track\n",
    "m = torch.tensor([2.0], requires_grad=True)\n",
    "b = torch.tensor([1.0], requires_grad=True)\n",
    "\n",
    "def linear_model(x, m, b):\n",
    "    \"\"\"Return m @ x + b: a 1-D dot product of m and x plus the bias b.\"\"\"\n",
    "    return m @ x + b\n",
    "\n",
    "y = linear_model(torch.Tensor([4.]), m, b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([9.], grad_fn=<AddBackward0>)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<AddBackward0 at 0x11ad4c6d0>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y.grad_fn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    y = linear_model(torch.Tensor([4]),m,b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([9.])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "y.grad_fn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([4.])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y = linear_model(torch.Tensor([4.]), m,b)\n",
    "y.backward()\n",
    "m.grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1.])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b.grad"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 3.4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_model(model, mode='static', display=True, max_moves=15):\n",
    "    \"\"\"Play one Gridworld game greedily (no exploration) with `model`.\n",
    "\n",
    "    Args:\n",
    "        model: network mapping a (1, 64) state tensor to one value per action in action_set.\n",
    "        mode: Gridworld initialization mode, passed as Gridworld(mode=mode).\n",
    "        display: if True, print the board and the chosen action after every move.\n",
    "        max_moves: the game is abandoned (counted as lost) once the move counter\n",
    "            exceeds this value; 15 matches the original hard-coded limit.\n",
    "\n",
    "    Returns:\n",
    "        True if the game was won, False otherwise.\n",
    "    \"\"\"\n",
    "    i = 0\n",
    "    test_game = Gridworld(mode=mode)\n",
    "    state_ = test_game.board.render_np().reshape(1,64) + np.random.rand(1,64)/10.0\n",
    "    state = torch.from_numpy(state_).float()\n",
    "    if display:\n",
    "        print(\"Initial State:\")\n",
    "        print(test_game.display())\n",
    "    status = 1\n",
    "    while(status == 1):\n",
    "        with torch.no_grad():  # evaluation only: no autograd graph needed\n",
    "            qval = model(state)\n",
    "        qval_ = qval.numpy()\n",
    "        action_ = np.argmax(qval_)\n",
    "        action = action_set[action_]\n",
    "        if display:\n",
    "            print('Move #: %s; Taking action: %s' % (i, action))\n",
    "        test_game.makeMove(action)\n",
    "        state_ = test_game.board.render_np().reshape(1,64) + np.random.rand(1,64)/10.0\n",
    "        state = torch.from_numpy(state_).float()\n",
    "        if display:\n",
    "            print(test_game.display())\n",
    "        reward = test_game.reward()\n",
    "        if reward != -1: #if game is over\n",
    "            if reward > 0: #if game won\n",
    "                status = 2\n",
    "                if display:\n",
    "                    print(\"Game won! Reward: %s\" % (reward,))\n",
    "            else: #game is lost\n",
    "                status = 0\n",
    "                if display:\n",
    "                    print(\"Game LOST. Reward: %s\" % (reward,))\n",
    "        i += 1\n",
    "        if (i > max_moves):\n",
    "            if display:\n",
    "                print(\"Game lost; too many moves.\")\n",
    "            break\n",
    "    win = status == 2\n",
    "    return win"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Initial State:\n",
      "[['+' '-' ' ' 'P']\n",
      " [' ' 'W' ' ' ' ']\n",
      " [' ' ' ' ' ' ' ']\n",
      " [' ' ' ' ' ' ' ']]\n",
      "Move #: 0; Taking action: l\n",
      "[['+' '-' 'P' ' ']\n",
      " [' ' 'W' ' ' ' ']\n",
      " [' ' ' ' ' ' ' ']\n",
      " [' ' ' ' ' ' ' ']]\n",
      "Move #: 1; Taking action: d\n",
      "[['+' '-' ' ' ' ']\n",
      " [' ' 'W' 'P' ' ']\n",
      " [' ' ' ' ' ' ' ']\n",
      " [' ' ' ' ' ' ' ']]\n",
      "Move #: 2; Taking action: d\n",
      "[['+' '-' ' ' ' ']\n",
      " [' ' 'W' ' ' ' ']\n",
      " [' ' ' ' 'P' ' ']\n",
      " [' ' ' ' ' ' ' ']]\n",
      "Move #: 3; Taking action: l\n",
      "[['+' '-' ' ' ' ']\n",
      " [' ' 'W' ' ' ' ']\n",
      " [' ' 'P' ' ' ' ']\n",
      " [' ' ' ' ' ' ' ']]\n",
      "Move #: 4; Taking action: l\n",
      "[['+' '-' ' ' ' ']\n",
      " [' ' 'W' ' ' ' ']\n",
      " ['P' ' ' ' ' ' ']\n",
      " [' ' ' ' ' ' ' ']]\n",
      "Move #: 5; Taking action: u\n",
      "[['+' '-' ' ' ' ']\n",
      " ['P' 'W' ' ' ' ']\n",
      " [' ' ' ' ' ' ' ']\n",
      " [' ' ' ' ' ' ' ']]\n",
      "Move #: 6; Taking action: u\n",
      "[['+' '-' ' ' ' ']\n",
      " [' ' 'W' ' ' ' ']\n",
      " [' ' ' ' ' ' ' ']\n",
      " [' ' ' ' ' ' ' ']]\n",
      "Game won! Reward: 10\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_model(model, 'static')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 3.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import deque\n",
    "# NOTE(review): `model`, `optimizer` and `epsilon` carry over from the earlier cells;\n",
    "# epsilon is not annealed in this loop.\n",
    "epochs = 5000\n",
    "losses = []\n",
    "mem_size = 1000    # capacity of the experience-replay buffer\n",
    "batch_size = 200   # transitions sampled per gradient step\n",
    "replay = deque(maxlen=mem_size)\n",
    "max_moves = 50     # cap on episode length so a stuck agent cannot loop forever\n",
    "h = 0\n",
    "for i in range(epochs):\n",
    "    game = Gridworld(size=4, mode='random')\n",
    "    state1_ = game.board.render_np().reshape(1,64) + np.random.rand(1,64)/100.0\n",
    "    state1 = torch.from_numpy(state1_).float()\n",
    "    status = 1\n",
    "    mov = 0\n",
    "    while(status == 1): \n",
    "        mov += 1\n",
    "        with torch.no_grad():  # only needed to pick an action\n",
    "            qval = model(state1)\n",
    "        qval_ = qval.numpy()\n",
    "        if (random.random() < epsilon):\n",
    "            action_ = np.random.randint(0,4)\n",
    "        else:\n",
    "            action_ = np.argmax(qval_)\n",
    "        \n",
    "        action = action_set[action_]\n",
    "        game.makeMove(action)\n",
    "        state2_ = game.board.render_np().reshape(1,64) + np.random.rand(1,64)/100.0\n",
    "        state2 = torch.from_numpy(state2_).float()\n",
    "        reward = game.reward()\n",
    "        # Any reward other than -1 ends the game, whether won OR lost, so both are\n",
    "        # terminal transitions whose target must not bootstrap from state2.\n",
    "        # (Previously only wins (reward > 0) were flagged as done.)\n",
    "        done = reward != -1\n",
    "        exp = (state1, action_, reward, state2, done)\n",
    "        replay.append(exp)\n",
    "        state1 = state2\n",
    "        \n",
    "        if len(replay) > batch_size:\n",
    "            minibatch = random.sample(replay, batch_size)\n",
    "            state1_batch = torch.cat([s1 for (s1,a,r,s2,d) in minibatch])\n",
    "            action_batch = torch.Tensor([a for (s1,a,r,s2,d) in minibatch])\n",
    "            reward_batch = torch.Tensor([r for (s1,a,r,s2,d) in minibatch])\n",
    "            state2_batch = torch.cat([s2 for (s1,a,r,s2,d) in minibatch])\n",
    "            done_batch = torch.Tensor([d for (s1,a,r,s2,d) in minibatch])\n",
    "            \n",
    "            Q1 = model(state1_batch)\n",
    "            with torch.no_grad():\n",
    "                Q2 = model(state2_batch)\n",
    "            \n",
    "            # Bellman target; (1 - done) zeroes the bootstrap term on terminal transitions\n",
    "            Y = reward_batch + gamma * ((1 - done_batch) * torch.max(Q2,dim=1)[0])\n",
    "            # Q-value of the action actually taken in each sampled transition\n",
    "            X = Q1.gather(dim=1,index=action_batch.long().unsqueeze(dim=1)).squeeze()\n",
    "            loss = loss_fn(X, Y.detach())\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            losses.append(loss.item())\n",
    "            optimizer.step()\n",
    " \n",
    "        if reward != -1 or mov > max_moves:\n",
    "            status = 0\n",
    "losses = np.array(losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x11ad9ad10>]"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD4CAYAAADFAawfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3dd3wc5Z0G8OdnybZcg7FFSSiy6RBqBIGQQAKXhHa0wIVc4AIcIZdKQnKcIRAcSDEQCM2GUGJTDdjYBHDHveAiWbZlW7IkS7IsWVa1el3te3/sjLQ7O7M7K+1oX0nP9/Pxx6vZ2d13R9pnZ94qSikQEZG+hiW6AEREFBmDmohIcwxqIiLNMaiJiDTHoCYi0lyyF086adIklZaW5sVTExENSpmZmdVKqVS7+zwJ6rS0NGRkZHjx1EREg5KI7He6j1UfRESaY1ATEWmOQU1EpDkGNRGR5hjURESaY1ATEWmOQU1EpDmtgvr5FflYk1eV6GIQEWlFq6B+afU+bCioTnQxiIi0olVQExFROAY1EZHmGNRERJpjUBMRaY5BTUSkOQY1EZHmtAtqpVSii0BEpBWtglok0SUgItKPVkFNREThGNRERJpjUBMRaY5BTUSkOQY1EZHmXAW1iPxGRHaLyC4RmSMiKV4XjIiIAqIGtYh8CcCvAKQrpb4MIAnAbV4XjIiIAtxWfSQDGCUiyQBGAzjoVYE43oWIKFTUoFZKlQH4G4ASAOUA6pVSy6z7ici9IpIhIhlVVb1bpYXjXYiIwrmp+pgA4AYAkwF8EcAYEbndup9S6hWlVLpSKj01NTX+JSUiGqLcVH38G4AipVSVUqoTwHwAX/O2WEREZHIT1CUALhaR0SIiAK4EkONtsYiIyOSmjnozgHkAtgHINh7zisflIiIiQ7KbnZRSjwJ41OOyEBGRDY5MJCLSnHZBzW7UREShtApq4coBRERhtApqIiIKx6AmItIcg5qISHMMaiIizTGoiYg0x6AmItIcg5qISHPaBTUXDiAiCqVVUHO4CxFROK2CmoiIwjGoiYg0x6AmItIcg5qISHMMaiIizTGoiYg0p11QKy4dQEQUQq+gZkdqIqIwegU1ERGFYVATEWmOQU1EpDkGNRGR5hjURESaY1ATEWlOu6DmfNRERKG0Cmp2oyYiCqdVUBMRUTgGNRGR5hjURESaY1ATEWmOQU1EpDkGNRGR5hjURESa0yqoRdiTmojIylVQi8gRIjJPRHJFJEdELvG6YEREFJDscr/nACxRSt0iIiMAjPawTEREFCRqUIvIeACXAbgTAJRSHQA6vC0WERGZ3FR9TAFQBWCWiGSJyGsiMsa6k4jcKyIZIpJRVVUV94ISEQ1VboI6GcAFAF5SSp0PoBnAVOtOSqlXlFLpSqn01NTUOBeTiGjochPUpQBKlVKbjZ/nIRDcnlCc55SIKETUoFZKHQJwQEROMzZdCWCPF4Vh7zwionBue338EsA7Ro+PQgB3eVckIiIK5iqolVLbAaR7XBYiIrKh1chEIiIKx6AmItIcg5qISHMMaiIizWkX1OxFTUQUSqugZjdqIqJwWgU1ERGFY1ATEWmOQU1EpDkGNRGR5hjURESaY1ATEWlOu6DmdNRERKG0CmrhhNRERGG0CmoiIgrHoCYi0hyDmohIcwxqIiLNMaiJiDSnXVArTnRKRBRCq6Bm5zwionBaBTUREYVjUBMRaY5BTUSkOQY1EZHmGNRERJpjUBMRaU67oOY0p0REobQKas5ySkQUTqugJiKicAxqIiLNMaiJiDTHoCYi0hyDmohIcwxqIiLNaRfU7EZNRBTKdVCLSJKIZInIp94Vhx2piYisYjmjvg9AjlcFISIie66CWkSOA3AtgNe8LQ4REVm5PaN+FsADAPweloWIiGxEDWoRuQ5ApVIqM8p+94pIhohkVFVVxa2ARERDnZsz6ksBXC8ixQDeA3CF
iLxt3Ukp9YpSKl0plZ6amhrnYoZr6+xCRUOb56/jRnF1M55ckgvFqf+IyANRg1op9aBS6jilVBqA2wCsVErd7nnJovjJW5n46l9WJLoYAIC739iKmav3oaS2JdFFIaJBSL9+1C5PStfk6VO94uvimTQReSc5lp2VUqsBrPakJOB81EREdrQ7oyYiolAMaiIizTGoiYg0x6AmItIcg5qISHMaBjW7uhERBdMqqNk7j4gonFZBTURE4RjURESaY1ATEWmOQU1EpDkGNRGR5hjURESa0y6oOfc+EVEorYKa05wSEYXTKqiJiCgcg5qISHMMaiIizTGo44gNoUTkBQa1jXmZpdhSVOt6fzaCEpGXYlrcdqj43dwdAIDi6dcmuCRERBqeUQ/26oOKhjbc8fpm1Ld2JrooRDRAaBXU4tGM1LvK6tHU7vPkuWM1c1UB1uVXY8G20kQXhYgGCK2C2gudXX5c98J63PPG1n57zZYOH7JL6/vt9YhocBvwQa2i1JX4jfu37a/rj+IAAH41Jwv//uJ6NLSxeoOI+m7ABfXr64uw+6DeZ6vbDwS+FNo7/QkuCRENBgOu18fjn+5JdBFcU1yol4jiYMCdUTuZvaEIb35enOhiGNixmojiZ8CdUVspFRhwMu2TwJn2f12SltgCBeMJNRHFgXZn1IOhuoAjFYkonrQKagYcEVE4rYK6w+eHr2vgn1GbBs87IaJE0iqoa5o7MD+rLKbH7K9tcbVfPKpUKhrakDZ1IT7YeiDifrwwIKJ40iqoe6M5ytDweA5L31fVBAD40OXw78E+bwkR9Y8BH9T9aZhRiR4tf1nXTkTxxKCOgZm/0Yatm/pa3fLtZ9bglpc29uk5iGjgG/D9qPuTmGfUUfI3XtUt+ZVNcXkeIhrYop5Ri8jxIrJKRHJEZLeI3NcfBTN1+PxYnF3ueL+/HyuChxn5O9iqntflVyFt6kLUtXQkuihEZMNN1YcPwG+VUmcAuBjAz0XkTG+L1eNvy/bip+9sw7r8qv56SUdm3bPbL4eB0pj48pp9AIBdZQ0JLgkR2Yka1EqpcqXUNuN2I4AcAF/yumCmssOtAOC4IoqXYVjd1G7ZErnqw9w8UBsTB8OoUKLBKKbGRBFJA3A+gM02990rIhkiklFV1X9nv26jxW2gZ5Uc7r7d2RU6TalT1YdTLg+U2JMoX0BElFiug1pExgL4EMCvlVJh18hKqVeUUulKqfTU1NS4FbA3Z3n1LZ2obQ7Ut8Z6dnvTTOdeFj2NiZHLNNBOqGWQ1r0TDRauglpEhiMQ0u8opeZ7WySHMjjEn11onvvYMlzw+HK0dXahxOXIRTev2dM9z93jzbKtzK1Ah895EQGlVPdiA0REVm56fQiA1wHkKKWe8b5IsbFm5tq8nmqX+z/YjiufXtPr57aejfcMeIlyRh30wK3Ftbh7dgaeWJLruP+sDcW4ccaGhDWYrsuvBgB0+bkiDZGO3JxRXwrgDgBXiMh24981HperW7SzV+v9//XPLd231+ztW/BZz+G7e324zDOl0F0FE+nMPq+iEQBQajScJkpLR1dCX5+I7EUd8KKUWo8EVruaQdzXnhQ+v0JhVROmpI6N+bHvbi7B+FHJSJs4JlCmGB4bSwNdohvzuvyspSbS0YAZQu6U0+bZqBtXxFoNYrzoQwuy8Yt3s3qqPnqRqJG+Z3oa85yft6ndh9/N3eHYTTEeItWjE1HiDJigdlJe51xd0NzHS/mwxkQzUGPK6ficpb6xsRjzMku7B6fEQ31LJ+pbeoLfxzNqIi1pOddHh8+PEcmB75BoDXf9mS3F1c0A3DQm9tyOV9WNeRYfzzqocx9bFvKzr4tn1EQ60vKMetmeQ923owVdPEfTNVnmtra+5k/f2dZdJl+XH7+ak4V8m6oXs1Fw+uJcdBjhF3miJvcDTrwc9djRj6vrrMmrwi/nZPXb6xENZFqeUdudJTe22S8QEM8GOOukRE6Z6FcKuw824OMdB1Fc0+z4
fAuzy7H7YH3U13UTvv3R0Nibuvfe+pHRO+eFH5zfb69JZKesrhWVDW04/4QJiS6KIy3PqIMt21MBAPjfeTtt7/cyWsQhQVUMr1vTFH2EpN/4ZqpoaIv6fDNWxa+OmoiAS6evjDgiWQfaB3U0sU5z+rN3Mvt+5hj0cGv+Wut5nV4pePumwhoAwAsrCwAAFzy+HFc/t87V88RTf04ZS/ZyyhvwYaa7pd5o6NAyqF9dW4iDEXpzBIs1WxZlH3LsDWI9gy6scjdxv9lbQimF33ywI+S+1s4u47lDX+Oj7Qe797G+hdrmDuSUh06n0h8ZGq1hdmtxbb9WjwxFVz+3Dr+duyP6jjSkaBnU2WX1+PGbGa729TI4DkWoili2u6fB02w8nPbJHnyy42DIfsGDSDp8fszeWAwA2BHj3B5OjaYfZBzAhoLqmJ7L8TUiHMrleypw68uf4+1N++PyWkTknpZBDQCtLvtAe3mCF6mnxszVgbri4J4iwfOMWCkV+f7eemDeTvzwtbBZZ3slUg+aF1bmAwD2VTk3nupmfX51XPudEyWKtkHtthtaX3N6x4G6Pp2Vuw2uxbsOhW3r8iukTV2I/TXRZ/irbLQuYhB/kQ7DztLovVec+Lr8mL2hKGx+b6/d/vpmTF+cy/7hNOBpHNTukjqWIeRWi7LLccOMDViQVRZ4TZt9Sg/3fprUaNo63Y+c3Be00O3hZvu1DedsKXG8zw2/i9FDsX6pNbR1YvbGYkz7ZA/eMKp9+pvuLfpE0Wgb1MNsUnPvofBQXpdf3etFWc3Gwn0OjYYiwNefWNWr53bjs5yKXj3urtlbw7blVTTiwfnZuO/97b0uj5sIjiWmG9o6cc60ZfjTwhwAzn3h3Xh1bSFOfXhxrx6bXdb7qwEiHWgb1Hb1w999dq3tvuc9ttyTMqzMrYzr891jaSC1joSMJDggtx+oQ6WlobO9M3B5X1LTHHaf1ZaiWqRNXRi2Pd7d84LnEQF65vN20trR5VhN8edFOQNq0qiDda3sIUNxo29QezhUut2hysH6mvN06s9q+cxPXxy6EIFZ9uKaFlz0lxURn2pNnv0XkJtciSV7rMEf7Xd6xh+W4L/fcNfbx0sH61qRuf9w9B0d7DnYgK9NX9ndw4eorzQOau+SetaGYgA9/Z/3HnLXXzqRthTXhvw8P6sMu4Iu6a97Yb3t45RSeOSjXSGL9joua+aiHLHMrWINdbvqLKs1LnrGdPkV2n3eLXJw+VOr8L2XQuu1KxvbXJ8h7zemFdhcWBtlTyJ3tA1qNx/q3qpp7sD2A3XYczAwqKS3dcWJ5hTOwdp9fry1aT9umrkR97+/HUop5wmu4nypHn5GHZ9f6p2ztuC0h5d0/5xf0YjqpnbsLK1DdVPfe8d0WianKqxqwkV/XoFX1xX2+bnJ2cKd5RyV6UDboPay6mPOlhLcOGNDv3cXs/r9gl2ev0bw8l7zs8rQ5VeOvcPdVn247a1i7UTi9Dv91/Yyx947pYdbsCAr9MNrrvFo+vbf1+JbT63G9S9uwDWWofd9sdW4ijlgHEPr6w40Simsyq3s0xfy+vxqnP7IYjS0xX8Bi5+/u42jMh1oG9TRGp7iISxIXM72PJCaiFbvDa+PXmTTpxsANuyLHkTL91Tg9EeWYGepm5GVoUcq26YvdofPj/ve247v/N2+ofjmmRvxm/ftP7zBjbGNxu149je/9eXP0dDW2X11N9DbBt/begB3zd7ap7aX21/fjLZOP3IONkTf2ZBVctjTlYlisS6/CtM+3h3X52xq93l+JaBtUFsvP71grV6J59zW8bQyN35VMwpAQaV9nXxWSaDqIG3qQny686DtPmYQ7jhQh0c+2mXbe8Rk/SK0G/QT7aomUvDe8KJ91c+CrFKssvmC6o1X1xZ2f4H3pVdMWV0rmmPo5ePW/R9sx80zN7ja1xwTcKg++iyN0bg9El1+hZtmbsRds7ZE39nG/ppmbI9xuoVVeysd1/+84/Utrht5
61o6kDZ1Yfd0EZ1dfkz7eDdqLNVrD83Pxm/n7oi5nLHQNqitkxJ5wcsGy3i6e3bsPSHK6lptu7pFy5o8o6/625v2QynlWM0hIngryrwfTsEWPIioL1+N+6qaccYjS8K2/+b9HbhrVnhf8954YWVBL5dgC3Xp9JX4j3983qvHFlY1OXZbnL+tDNtK3AWEdRGOfVVNuOeNrTENvDK5rTY0/wZ29HJk6+VPrcaNM9x9EQHAqtxK3DVrK2auKujV6wXLqwic0LyyNtA2sWx3BWZvLMZjn+4J2c/84nPqTRYP2gZ1ItS16HF5Fg+XTl+JaZ+EX+JFu2qoCjpbeH5FAU5/ZElYf2jAuWpKKdX9IfY7fJaDBxH1tQGz1cMPh8l8p3294todQ3WBKfdQA654eg0uf2p1n14b6PlSNE9QHl6wC5/lVPaqK+LS3fbVZ1b9fSpU2RgIzZLavo8o7ll0OsBn/EFbz9bNL6NhHvaAGNJBbT2u72weXDPDLd1dEXYWGC0X73svMLJxX1VzdyNerc3Iz4cWZNs+fvKDi3DK7xejqrHdVbDFGn1LdpXH+IhQq/dW4p4Y+2p3GQetuT22L4V4VKVd9WygcbTMmPZ3xqoCzNlSgtrmDsxcHdtZoxko1lXvexMvutbXm9VUvSle5v7DIVcKZmOytcrKeiVu/n142VNN66B2s+JJ3wyMqo/echuWTo81zdlSEnX/619cj39tL+v+OTAyL3y/F41Z+Ey1TbEN//+ft7fFtL/pNaNr3Z2ztkbsjmk3a+NHWYH6enMoel1LR8jf5lXPrsVfF+XA71d4afU+vBRhxj5flz+sjtNOh89ve7Xx1NK9eHB+Nh6YtxNPLtkb9XlCGE/X6VN4ZtlebDL7eTt8DDYUVIf8HQSLtb7ePAvNr2hE2tSFtg3LwXxd/rAqmdfWFSJt6kLbKzxTuxG0vfki+d5LG/HU0p5jah7fXKM6sLvqyPI48wTbyw4QWgf1V6OMsOur/Zb1DlflupuG1MuhwaviPGzd2qbiNAw/ErOOzklbZxd2ltZ3n40DgUyw+zD/bVleyM//+eqmmMvTG+Z8I05+8lYGVuVW2g64sQ6uOe+x5SF/m7mHGvGPtYVYsvsQnliS2z3T4MG68BONRz/eja/86TO0dERuWDz14cUReyc0tcdeTWf+Pv7+WR6eX9lzNi4QfL6vBmlTF4aU64evbcatLwcG/rR0+EJ62VirtTp8fsxYVRB1mP9nOYG/74939Hyp201pcPcbGTjd0v5gnjDkVzY6fgYf/Vegy2tBZSBc/2/eTry7OfxE49W1hbZ135HaxrqvQCx5bE5mNmSD2mvtlj+qMperyhS7mJa0t+wmXOoLa1i6mVLVVO6yd8B6m/7FRdVN2FBQE/WxB+PQA8GtM//Q88G3zhS4dHeFcezDA2BzkbsRhtssdb3ZZfVYuvtQSPiZPV9aXMy3/maExtpI7SmVjW1Im7owrB7Z6e9WBPiB8YVpjtq1Pib9T5/hy48u7d5uXu6X17di+4E6zNpQhKeW7sWsDUUhjz9wOPQz9cSSwNQHr64r6q4bn5d5IKxMdnO3m1MK3/Ly593zwVuZv9YGYwKw9zMO4KEF2dho6Xr650U5tr00fBF6m5kfpXLLF7B5pRXvuYGCDemg1rU7Xjx9lFUWfScH1i8yJ9bJpoBAzwvzQxkrpRQa2jpRXB3fRQqCw9HnV8ivaMT3//E5SoICrNqmKsbp8t/K7mz8J29lhpx9dzdM2vzpLd8TWiUT6cIt12YmSVNOeWP3ay/K7qnTtz6/Kfj9B1/6B7N+sZh9sS/560rcOGMD/mrMPWM27s7NOID9Nc0hVUnWul5zmL6b8QvWVYyWRWnMtD7jf77qbnENs4rGboSrWf4tDkvSPbci37NBdMmePOsA0dAa/36tujG7GA0kb28uwSMfeTtqs8uv8KeFOdhcVIvLngrqhdKH50xyaE0Knt7V2pAHBOpj1+VXY4nDQKRYBRfjD//a
hWvOPjbi/g98uDPi/U5fVEURvkj/d95OHDF6eMiZ/+2v24fl+xnhZ9RW1lWM2o06fMcutjabb3exEtI2Y06cH/0zvN938EXYA/N24tpzjkXquJEh+zz68W785aazo75OrIZ0UOsyWopCeR3SQKCrVbJNsPZlNZhIZ7kmsw53bV41vn7yJBzzhRS8sLIAz63Ij/LIyD7deRAXT5mIy55chUumTAy6R/DelhL8+7lfdP1ch5s7MGHMiO6fL/zzZ7b7PbM8L2ybUj1tONbqmSyX/b3dyD3UiFfWFuKk1LG48oyjUFbXisPNPa9XYVOltt7F2qI+v8JXHl+OGpsFOILPoudmlmJuZilOSh0Tsk+81i+1GtJBTUPX2dOW2W7/4yd7bLfHi1l3+ru5OzBxzAhkPvJtzHA5OCNSCPzi3azu2yuC6kqrm9oxdX52TKPmzn98OYqnXxt1v102CzIEGpFdv1SfmNUtT3zvbPzfh6HdRZs7ulDby9WO7EIasL/asp7RR6rj7gsGNVEfRBpCH40ZCD6XydaXRYwjVVPY+eviyL1knJ7z+RX5rrpzmraV9H7eb5M1pE3r8uO3mLRTLxNrz7F4L75hYlAbDsRhJBNRrH7XT7PFue25YvrHmt5P6eq28RUAnvssvMqnMU4z82UUR/8ScLuM3zf/ttq2x5R1TiKnOUb6SqteH988LTVhr/2NJ71bG5EGj+Bqih/b9HaJlVarCCWAXU+ZO17v3QROVtHmogHcL+PntltrPGdvDKZVUL/0w68kughEEQV3X3Pq7kZ94+UsdAOVVkE9akRSootARKQdrYKaiIjCuQpqEblKRPaKSIGITPW6UERE1CNqUItIEoAZAK4GcCaAH4jImV4V6M6vpXn11EREA5KbM+qLABQopQqVUh0A3gNwg1cFuv87p7re9/JTE9dLpC8uG6Dl9tr304/H2JF69BgdOzIZk8aOjL4jUT9w86n4EoDgwfilAL5q3UlE7gVwLwCccMIJvS7Q+JThKJ5+LZ5ZtjdkKsbMh/8NHV1+NLb5cOLE0RiZ3NPw2OHz46pn1+Lp/zgX558woXu7r8uP5KRhKK9vRVF1M8aMSEa7z48L0yZARLoHK5xy1Fhc/eVjcPyRo9HW2YWvnTwJx08YjdkbizA+ZThuuyjwfjYV1uChBdkoNGbx+tZpqfify0/CV4OG7NY0tWPUiCSMHpGM+tZOJA0Tx/CpbmpHdVM7mtt9uOCECcg91IiN+2rw5JJcTL36dBw9PgVXnXUMdpTW4cwvjg95z36/wpSHFuH9ey/Gq+sKsamwFtnTvgMRQZdfYZiEjppqavchJXkYkoYJKhraMXpkEmqaOtDU5sO4lGSkTRpjV8RuczMO4Nzjj8BJqWOhlOo+rvtrWnDixNE4alwKDrd0oPRwK/IqGrF01yG8fueFAAJzPG8qqkGHz4/LT03FlU+vweJffwPjU4ZjQVYpLjhhAk6cGHj9J245B3UtHahuasev5mzH6ceOw9O3novGdh9GJA1DZ5cf41KGh5Wvw+dHa0cXxo9K7n7fSinkVTThlKPGdq++8fyKfNx20fEQSNg8DdHkVTRizMhkJA8THD0+BUDg932kMdy6tbMLScMEI5KGobHdh3HG711EUF7fimPGp0BEUFTdHFhey6/wrdOOwvAkCfld7TnYgM8La3DHxSeis8uPlOFJSBomePijbLR1+vHjb0zBqOFJSBk+DEcZ5ahuaseh+jacfNRYpAxPQltnF55ethcXph2Jb595dGD1eREMk8AskV8YNRzjUoZDKYX3th7ASaljMXnSGCil0NzRhcMtHXhofjbevPsi5BxqxOgRSXhmWR6evOUcHDdhFOZlluL6876Im2ZsxClHj8Wz3z8PFQ3tyNx/GGccOw5TUseGHDulFOZmlOLlNftw+8Un4u6vT0Zzuw9nPboU028+GzdfcBzW5FWhw+fHz9/dhk0PXomtxbUYPSIJ6WlHorCqCacdMw6jRyQj
c/9hfO+ljXj8hrNwxyVpAHoGHr1378U47/gjsK3kMC5MOxJdfoWU4Un4YOsBnHz0WNw8cyMmTxqDJ285B8eMT8E7m0tw16VpKKxqxqlHj0V5fRvW5Vdj8qTROHHiGBw3YRQqGtpw8lHjUF7fisqGdpx+7DgACPk8AoHj6rZfdqwk2tzKInIrgO8qpe4xfr4DwEVKqV86PSY9PV1lZPS9jykR0VAhIplKqXS7+9xUfZQCOD7o5+MA2C9RTUREcecmqLcCOEVEJovICAC3AfjY22IREZEpah21UsonIr8AsBRAEoB/KqWc1wgiIqK4ctXErpRaBGCRx2UhIiIbHJlIRKQ5BjURkeYY1EREmmNQExFpLuqAl149qUgVgOizdtubBMCbFSIHFh6HHjwWPXgsAgbjcThRKWU7v4QnQd0XIpLhNDpnKOFx6MFj0YPHImCoHQdWfRARaY5BTUSkOR2D+pVEF0ATPA49eCx68FgEDKnjoF0dNRERhdLxjJqIiIIwqImINKdNUA/WBXRF5J8iUikiu4K2HSkiy0Uk3/h/grFdROR54xjsFJELgh7zI2P/fBH5UdD2r4hItvGY5yV4qRCNiMjxIrJKRHJEZLeI3GdsH4rHIkVEtojIDuNY/NHYPllENhvv631jWmGIyEjj5wLj/rSg53rQ2L5XRL4btH3AfJ5EJElEskTkU+PnIXkcIlJKJfwfAtOn7gMwBcAIADsAnJnocsXpvV0G4AIAu4K2PQlgqnF7KoAnjNvXAFgMQABcDGCzsf1IAIXG/xOM2xOM+7YAuMR4zGIAVyf6PTsch2MBXGDcHgcgD4HFkofisRAAY43bwwFsNt7jBwBuM7a/DOCnxu2fAXjZuH0bgPeN22can5WRACYbn6GkgfZ5AnA/gHcBfGr8PCSPQ6R/upxR9+sCuv1JKbUWQK1l8w0A3jBuvwHgxqDtb6qATQCOEJFjAXwXwHKlVK1S6jCA5QCuMu4br5T6XAX+Yt8Mei6tKKXKlVLbjNuNAHIQWI9zKB4LpZRqMn4cbvxTAK4AMM/Ybj0W5jGaB+BK42rhBgDvKaXalVJFAAoQ+CwNmM+TiBwH4FoArxk/C4bgcYhGl6C2W0D3SwkqS384WilVDgQCDMBRxqh2ufwAAAIjSURBVHan4xBpe6nNdq0Zl6znI3AmOSSPhXG5vx1AJQJfNvsA1CmlfMYuweXvfs/G/fUAJiL2Y6SjZwE8AMBv/DwRQ/M4RKRLUNvVJQ7FfoNOxyHW7doSkbEAPgTwa6VUQ6RdbbYNmmOhlOpSSp2HwBqkFwE4w2434/9BeSxE5DoAlUqpzODNNrsO6uPghi5BPdQW0K0wLtVh/F9pbHc6DpG2H2ezXUsiMhyBkH5HKTXf2Dwkj4VJKVUHYDUCddRHiIi56lJw+bvfs3H/FxCoTov1GOnmUgDXi0gxAtUSVyBwhj3UjkN0ia4kD1QnIhmBRqHJ6Kn0PyvR5Yrj+0tDaGPiUwhtQHvSuH0tQhvQthjbjwRQhEDj2QTj9pHGfVuNfc0GtGsS/X4djoEgUG/8rGX7UDwWqQCOMG6PArAOwHUA5iK0Ee1nxu2fI7QR7QPj9lkIbUQrRKABbcB9ngB8Ez2NiUP2ODgen0QXIOgXdQ0CPQH2Afh9ossTx/c1B0A5gE4EvuH/G4F6tRUA8o3/zaARADOMY5ANID3oee5GoJGkAMBdQdvTAewyHvMijNGmuv0D8HUELjt3Athu/LtmiB6LcwBkGcdiF4A/GNunINBzpcAIq5HG9hTj5wLj/ilBz/V74/3uRVAvl4H2ebIE9ZA9Dk7/OISciEhzutRRExGRAwY1EZHmGNRERJpjUBMRaY5BTUSkOQY1EZHmGNRERJr7f3TPJER8qWH9AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Initial State:\n",
      "[['P' ' ' ' ' ' ']\n",
      " [' ' ' ' ' ' ' ']\n",
      " [' ' ' ' ' ' 'W']\n",
      " ['-' '+' ' ' ' ']]\n",
      "Move #: 0; Taking action: r\n",
      "[[' ' 'P' ' ' ' ']\n",
      " [' ' ' ' ' ' ' ']\n",
      " [' ' ' ' ' ' 'W']\n",
      " ['-' '+' ' ' ' ']]\n",
      "Move #: 1; Taking action: d\n",
      "[[' ' ' ' ' ' ' ']\n",
      " [' ' 'P' ' ' ' ']\n",
      " [' ' ' ' ' ' 'W']\n",
      " ['-' '+' ' ' ' ']]\n",
      "Move #: 2; Taking action: d\n",
      "[[' ' ' ' ' ' ' ']\n",
      " [' ' ' ' ' ' ' ']\n",
      " [' ' 'P' ' ' 'W']\n",
      " ['-' '+' ' ' ' ']]\n",
      "Move #: 3; Taking action: d\n",
      "[[' ' ' ' ' ' ' ']\n",
      " [' ' ' ' ' ' ' ']\n",
      " [' ' ' ' ' ' 'W']\n",
      " ['-' '+' ' ' ' ']]\n",
      "Game won! Reward: 10\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Play one game on a 'random'-mode board with the trained model, printing every move\n",
    "test_model(model, mode='random')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 3.6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Games played: 1000, # of wins: 908\n",
      "Win percentage: 90.8\n"
     ]
    }
   ],
   "source": [
    "# Play max_games games without rendering and count how many test_model reports as won\n",
    "max_games = 1000\n",
    "wins = sum(\n",
    "    1 for _ in range(max_games)\n",
    "    if test_model(model, mode='random', display=False)\n",
    ")\n",
    "win_perc = wins / max_games\n",
    "print(\"Games played: {0}, # of wins: {1}\".format(max_games, wins))\n",
    "print(\"Win percentage: {}\".format(100.0 * win_perc))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 3.7"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "\n",
    "# Online Q-network: the only network whose parameters the optimizer updates\n",
    "model = torch.nn.Sequential(\n",
    "    torch.nn.Linear(l1, l2),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(l2, l3),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(l3,l4)\n",
    ")\n",
    "\n",
    "# Target network: a copy of `model` used only to compute TD targets.\n",
    "# It is not registered with the optimizer; training re-syncs it from `model`\n",
    "# every `sync_freq` moves.\n",
    "model2 = copy.deepcopy(model)\n",
    "model2.load_state_dict(model.state_dict())  # explicit initial sync (deepcopy already copies the weights)\n",
    "sync_freq = 50  # NOTE: the training cell (Listing 3.8) redefines sync_freq = 500 before it is used\n",
    "\n",
    "loss_fn = torch.nn.MSELoss()\n",
    "learning_rate = 1e-3\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Listing 3.8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import clear_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4999 0.012652655132114887\n"
     ]
    }
   ],
   "source": [
    "from collections import deque\n",
    "epochs = 5000\n",
    "losses = []\n",
    "mem_size = 1000   # experience-replay capacity (oldest transitions are dropped)\n",
    "batch_size = 200\n",
    "replay = deque(maxlen=mem_size)\n",
    "max_moves = 50    # cutoff on moves per game\n",
    "sync_freq = 500   # copy model -> model2 every sync_freq moves\n",
    "j = 0             # global move counter; drives target-network syncing\n",
    "for i in range(epochs):\n",
    "    game = Gridworld(size=4, mode='random')\n",
    "    # flatten the rendered board to shape (1, 64) and add small uniform noise\n",
    "    state1_ = game.board.render_np().reshape(1,64) + np.random.rand(1,64)/100.0\n",
    "    state1 = torch.from_numpy(state1_).float()\n",
    "    status = 1\n",
    "    mov = 0\n",
    "    while status == 1:\n",
    "        j += 1\n",
    "        mov += 1\n",
    "        qval = model(state1)\n",
    "        qval_ = qval.data.numpy()\n",
    "        # epsilon-greedy action selection\n",
    "        if random.random() < epsilon:\n",
    "            action_ = np.random.randint(0,4)\n",
    "        else:\n",
    "            action_ = np.argmax(qval_)\n",
    "\n",
    "        action = action_set[action_]\n",
    "        game.makeMove(action)\n",
    "        state2_ = game.board.render_np().reshape(1,64) + np.random.rand(1,64)/100.0\n",
    "        state2 = torch.from_numpy(state2_).float()\n",
    "        reward = game.reward()\n",
    "        # Any reward other than -1 ends the game (see the status check below),\n",
    "        # so wins AND losses are terminal and must not bootstrap from state2.\n",
    "        # Hitting max_moves is only a cutoff, so those transitions stay done=False.\n",
    "        done = reward != -1\n",
    "        exp = (state1, action_, reward, state2, done)\n",
    "        replay.append(exp)\n",
    "        state1 = state2\n",
    "\n",
    "        if len(replay) > batch_size:\n",
    "            minibatch = random.sample(replay, batch_size)\n",
    "            state1_batch = torch.cat([s1 for (s1,a,r,s2,d) in minibatch])\n",
    "            action_batch = torch.Tensor([a for (s1,a,r,s2,d) in minibatch])\n",
    "            reward_batch = torch.Tensor([r for (s1,a,r,s2,d) in minibatch])\n",
    "            state2_batch = torch.cat([s2 for (s1,a,r,s2,d) in minibatch])\n",
    "            done_batch = torch.Tensor([d for (s1,a,r,s2,d) in minibatch])\n",
    "            Q1 = model(state1_batch)\n",
    "            with torch.no_grad():\n",
    "                Q2 = model2(state2_batch)  # TD targets come from the target network\n",
    "            Y = reward_batch + gamma * ((1 - done_batch) * torch.max(Q2, dim=1)[0])\n",
    "            X = Q1.gather(dim=1, index=action_batch.long().unsqueeze(dim=1)).squeeze()\n",
    "            loss = loss_fn(X, Y.detach())\n",
    "            loss_value = loss.item()\n",
    "            print(i, loss_value)\n",
    "            clear_output(wait=True)\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            losses.append(loss_value)\n",
    "            optimizer.step()\n",
    "\n",
    "            if j % sync_freq == 0:\n",
    "                model2.load_state_dict(model.state_dict())\n",
    "        if reward != -1 or mov > max_moves:\n",
    "            status = 0\n",
    "            mov = 0\n",
    "\n",
    "losses = np.array(losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x120a077d0>]"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWoAAAD4CAYAAADFAawfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3deZgdZZ0v8O8vLAIqAtrjRcAJzLghI1sbURzGAZR1ZGauKCjeUfFmVAZh9AGDjMNwkUXFAHGDsAgoJJCwCQkhISEkgWwdsu+dpLMn3elOOkl30tv53T9O1ek6dare81adc/pU1/l+nidPTtf6nuquX731rqKqICKi5BpS7QQQEZEZAzURUcIxUBMRJRwDNRFRwjFQExEl3KGVOOgHPvABHTp0aCUOTUSUSgsWLNilqnVB6yoSqIcOHYqGhoZKHJqIKJVEZGPYOquiDxH5TxFZLiLLRGSMiBxRvuQREZFJ0UAtIicA+CGAelU9DcAhAK6qdMKIiCjLtjLxUABHisihAI4CsK1ySSIiIq+igVpVtwK4F8AmANsBtKvqZP92IjJcRBpEpKGlpaX8KSUiqlE2RR/HArgCwMkAPgTg3SJyjX87VR2tqvWqWl9XF1hxSUREMdgUfVwIYIOqtqhqD4DnAXyusskiIiKXTaDeBOAcETlKRATABQBWVjZZRETksimjngtgPIB3ACx19hld4XTFMnHpdrR1dFc7GUREZWXV6kNVb1PVj6vqaar6TVXtqnTComrZ14UfPPUOhj/JjjZElC6pGeujpy8DANi650CVU0JEVF6pCdRERGnFQE1ElHAM1ERECZe6QM25eokobVITqEWqnQIiospITaAmIkorBmoiooRjoCYiSrjUBWoFaxOJKF1SE6gFrE0konRKTaAmIkorBmoiooRjoCYiSjgGaiKihEtdoGYXciJKG5vJbT8mIos8//aKyI0Dkbgo2IWciNLq0GIbqOpqAGcAgIgcAmArgBcqnC4iInJELfq4AMA6Vd1YicQQEVGhqIH6KgBjglaIyHARaRCRhpaWltJTRkREACIEahE5HMCXAYwLWq+qo1W1XlXr6+rqypW+yFiXSERpEyVHfQmAd1R1Z6USUwrWJRJRWkUJ1FcjpNiDiIgqxypQi8hRAL4I4PnKJoeIiPyKNs8DAFXtBPD+CqeFiIgCpK5nIhFR2qQuULMLORGlTXoCNZt9EFFKpSdQExGlFAM1EVHCMVATESVcCgM1axOJKF1SE6g5CzkRpVVqAjURUVoxUBMRJRwDNRFRwqUuULNnIhGlTWoCNSe3JaK0Sk2gJiJKKwZqIqKEY6AmIko42xlejhGR8SKySkRWishnK50wIiLKsprhBcADACap6lec2ciPqmCaSsJGH0SUNkUDtYgcDeA8AN8CAFXtBtBd2WRFx0YfRJRWNkUfpwBoAfBHEVkoIo+IyLsrnC4iInLYBOpDAZwF4A+qeiaADgAj/BuJyHARaRCRhpaWljInk4iodtkE6i0AtqjqXOfn8cgG7jyqOlpV61W1vq6urpxpJCKqaUUDtaruALBZRD7mLLoAwIqKpqoEyj7kRJQytq0+rgfwlNPiYz2Ab1cuSfEI+5ATUUpZBWpVXQSgvsJpISKiAOyZSESUcAzUREQJl7pAzapEIkqb1ARqViUSUVqlJlATEaUVAzURUcIxUBMRJRwDNRFRwqUuULMHORGlTWoCNXuQE1FapSZQExGlFQM1EVHCMVATESVc6gI1x6MmorRJTaAWdiInopRKTaAmIkorBmoiooSzmuFFRJoA7APQB6BXVTnbCxHRALGdMxEA/lFVd1UsJWXCqkQiSpv0FH2wLpGIUso2UCuAySKyQESGB20gIsNFpEFEGlpaWsqXQiKiGmcbqM9V1bMAXALgOhE5z7+Bqo5W1XpVra+rqytrIomIaplVoFbVbc7/zQBeADCskokiIqJ+
RQO1iLxbRN7rfgbwJQDLKp0wIiLKsmn18UEAL0h2HNFDATytqpMqmqpSsNkHEaVM0UCtqusBnD4AaSkJx6MmorRKT/M8IqKUYqAmIko4BmoiooRLXaBmXSIRpU3qAjURUdowUBMRJRwDNRFRwjFQExElXCID9c9fWYFxDZurnQwiokSIMnHAgHlk1gYAwJX1J0Xel7OQE1HaJDJHTURE/RioiYgSjoGaiCjhGKiJiBIudYGaVYlElDapC9RERGljHahF5BARWSgir1QyQURElC9KjvoGACsrlRDK17KvC2ffMQWrduytdlKIqMqsArWInAjgMgCPVDY55Jq2aidaO7rxmNP5h4hql22O+n4ANwPIhG0gIsNFpEFEGlpaWsqSuDjYMZGI0qZooBaRywE0q+oC03aqOlpV61W1vq6urmwJrHV88BCRTY76XABfFpEmAGMBnC8if65oqggCTqtORFlFA7Wq3qKqJ6rqUABXAZimqtdUPGVERASA7aiJiBIv0jCnqjodwPSKpIQCsYiaiFKXo9a0hLaYRdR7OrsxesY6jstNlCKpC9S17qcvLMVdE1dh7oa2aieFiMqEgTrhomaM9x7oBQD09IU2eSeiQYaBOqHYOI+IXAzUKcUiaqL0SF2gTluAilo5KsyKE6VO6gJ1WggjLhE5GKhTKmUvFkQ1jYE66RhxiWpeagP1D55agKEjJlQ7GbGx4IOIXKkL1G4GdOLSHVVNh43O7l589L9exWvLy59W9kwkSo/UBerBZFNbJ7p7Mxg5eU3oNlHDLSshidKHgTqhGG+JyMVAnVIs+CBKDwbqBDB1aola1lyJjPhN4xbjtNteq8CRichGpPGoqbxM020lqehj3IIt1U4CUU1LX46a7/xZvA5EqWEzC/kRIjJPRBaLyHIRuX0gElZLTKUb0Vt9lJSUqujs7sWc9a3VTgZRYtnkqLsAnK+qpwM4A8DFInJOZZNVG0xBdbDOQr5m577I+9w0bgmuGj0H2/YcqECKiAY/m1nIVVX3Oz8e5vzji3XCVWNKsknLtuNL983Ay4u3Rdpv1Y69ALI5ayIqZFVGLSKHiMgiAM0Apqjq3IBthotIg4g0tLS0lDudg9rsda1o7+yJtW/UDob+fPgbq5uxrmV/4Lbltmbnfuf/6LlqIH1D1BKVi1WgVtU+VT0DwIkAhonIaQHbjFbVelWtr6urK3c6rSVtctuDPX24+uE5+Pbj80K3CUpxucqav/3H+bjg12+W52AVwt6URGaRWn2o6h4A0wFcXJHUpFBvJhuGV++Il8scTJgjJqoMm1YfdSJyjPP5SAAXAlhV6YSljSmGmfKTcWNfNYNm3Pwx4zxRMJsOL8cDeEJEDkE2sD+rqq9UNlnpYRO0yhmgqlmMELfYiQUfRGZFA7WqLgFw5gCkJdWCcrgMUPlYdEIULHU9E5N2s5eawY07rnQ1rkPctt+sSyQyG9SBurcvg90d3dVORkXELcKoZswrtcVN0lrsECXFoA7UP3luKc68Ywp6+zLVTkpR5RwhL/EiPmQGay9MooEyqAO12wOuL8GBrtQR8mK3+oi5HxElz6AO1K6kxen2Az1oaGrLW1ZqD0Pr/aqYOS3195C03yNRUgzuQJ3QN+bvPD4fX3lwNrp6+yqaa04LViYSmQ3uQB0gCUFv2dZ2APk5xOB0VS6KV6Pcu/QWLuVJB1HaDOpAnfSMWCmBJ37Qq2KHFwZaoooY1IE6qSIHWSfANTbvw9vrdpU9PQMt6Q/QpFjXsh9DR0zAG6uaq50USrhUBOqk5uTymuRZpPHCkTPw9YfzR5B1j9HV24eDPX0Rzj3wSj1nrbWjXrBxNwBgwtLtVU4JJV0qAnXSeJvkWeWuA7bxN+s7955p+PjPJhU/VITsbOv+LnT3lr8NetQ3ilof5jSpGQ1KjtQF6mp1Hmk/0FMQ9PIrE0tL16795e+BefbPX8cNYxeW/bhx1VrAqu3HE0UxqAN1kjJip98+Gdc+MR9A
/DLqcrINeq8u21H+k0eUoF9jVdRakQ9FN6gDtSspf+gz1+ZXBHpT5QbOLbs78VZjdjurNtYD1FGGBl6tF/mQPZvxqBMrqWNEmFJ1/q/fRHdvBk33XGY+RjK/WkXU0ncNlIx8BiVYOnLUCf1D95aXu5+CKu+iJr+7N4On5m5EJmPaswoXJam/iISq9ecT2bOZiuskEXlDRFaKyHIRuWEgEmbDNie2c+9BDB0xAS8t2lrZBDlsX2mtZn8JiH0PvbkOt76wDOMXbAk4t9WpKyrum06txvka/doUgU2OuhfAj1X1EwDOAXCdiJxa2WRFoyGfXWt2ZieWfbZh84Ckx5SWctjd2QMA2HuwJ9J+izfvwfVjFqLPmBOPL+5Rk/BwqYZa/d4UXdFArarbVfUd5/M+ACsBnFDphNkI+jsPnvJKQtdVQnC6Ys4nGNTGOmYl5Pf+vAAvL96GnXsPxkpLpSWlUnigpW48ciq7SGXUIjIU2fkT5wasGy4iDSLS0NLSUp7UWXpm/mZjEHYD20DfD7bnM92opmOYHkrG8xVZv/dgD4aOmIAn3m4qeqzZ61oxdeVO59zxJLVSuNKYoyZb1oFaRN4D4DkAN6rqXv96VR2tqvWqWl9XV1fONBZ1xysrMG3VztD1A34/BJwwKDiayrJNaa7E95m1dhd+9MwiAECzk+N+cnZT0f2ufngOrn2iAQDLWuPidaNirAK1iByGbJB+SlWfr2ySzHa0Hwwc82J/V/FxMNxX68bm/Vi8eU/Z0xZwQvPqErP4xum9ApaZAvw1j87F8wvzK1sHuszZvRyPztqQGyo2zWr1TYKis2n1IQAeBbBSVUdWPklm59w9Fd/647yC5cY/eV/Rx4Uj38QVv3ur7GkzpSVuTA4KxpWf2dy9YKWdx5b/+9zxygpc/ptZA3PyBGARNRVjk6M+F8A3AZwvIoucf5dWOF057Qd68MLC/GZoc9Znp7nyFh2YgleuMjHiuVUVq3YUlPLY71/kjMaijxJ7LQbtbt1ssMQccVxBu3d09WKt02qnXF5evA2vJmDEOpZRky2bVh+zVFVU9VOqeobzb+JAJA4AfvzsYvznM4tzTexcezrtBynK3RARA8kTbzfh4vtnYu761kj7VbprsHv82BPf2lZyxjx+kL6MYt6GtsB1pqv1ncfn44v3zShjSoDrxyzE9596p6zHLAUz1FRM4nsm7th7AAAKyqX/6bez8m5w42zfzv9Rm38t3ZrNTW9q64y0XzkFt+yw2K+Ec+auV8DJO7p68bWHZmN9y/7C/QwJGzV1Lb760GzMbwoO1mHnmxsS3IlqSeIDdVgA3tx2IH87U9GHmwONGL1KbddbWs7Vopmds+OMNS1YsiVbOVqO8mvTG8H01S2Yu6EN905eHZoeV19Gc51r1jZn34gC23DHTHTTrg7MWDOwTUErge2oqZhBMyhTsb9lY3O2uHVj6u4fLZDYbl6uaRH/z2PZytViAz1FZRxJxKJ8/JO3TcKxRx2O2bdcUHILh6AHyBfunQ6g/N97oHD0PLKV/By1ZZA1VyY6xwiILl29fbh74kp0dPUW3T+qcuSTzMMuxesoYzyfeq9X4XpT5yF/eg72ZLC9PT8HbezAUyRdaZXir0ZlkvhAbSTej8VbUATdEE/P3YSHZqzHb6Y1FqyL3Y444vamgBh8/PCiHHMRkF16rB56hqsTuL/hd1DpMvekYn6abA3uQO1l0ZA6KLD19mUX9vQFDD/q7BC/uVqx5nnxjlt6ObRhXd7n8Dbc0bu2E5DNGPzPX5bnL0zjU4jKKvGB2lRsISGfIx3fF3iadnVgU2u2lYf6trE/ZvnCUjU6ypjLk0svax7I/ZLmpy8sxePOGCosoiZbya9MtO6kEa/ow+UGtqAKqtjjK8fayz1n3HU2gzKZyrY9kx3ELE8OHvEv3jUUyaYjHWE6WK2OGkj2Ep+jdhUb1MgqsAXlymM23TMZqIxS7KFTi6TQVLwRdyTCuJMkSN56xcjJ
q7EuoA33YMSxPsjWoAnUpYg7Sp0bOKrVpdo5SsGSuOXE/UeMNphT3rkNW9p836gVoN79WvZ3YdS0RnzzkYJRdhNl8vIduOX5pdbbB7agUUVjc3m7ztPglZpAbXWzR9yvmjOWmB8uFs0you5nyeYNJCjt/cVP5gKo0POhv/yjuy/ZRQXD/7QAY+ZtKrqd6e/kpUXbcOHIGcbhe6l2JD5Qm9r05m3n+aP3TzVlc4ygYoRyD0Pa1hE8PkncMspKjPWh6s2xmytww48RbT+bdareDQuP39uXwS3PL8XWPQcK1lXblBU7je30g34fK7Znhy9YszMdxTxUmuQH6hgtFS4bNTP3ubO715ibM43X1N/qozxliWfdMcV3bvtu4nn7VeDtIX87mwrJeNzvc85dUzHiuSWB6/LT4vlsaGY5e30rxszbhJ+MX1K4sooam/fj/z7ZgJufK0yX9QOKal7iA3U/c3Mzby561Y7+sr2ePi3SQcTiVT5iSsuhHJVvofsZ16lxO1Nuu3+boKKP/Ou8Y+9BjJ2/OXT74OOGp8tfn9DZ3Yv2A9Em/wWAldv3ljVX7uak3SafQQLfqKyKiqhWJD5Qm27hPZ39N+LLS7YFbpPJqF0rhqAbIvY9UtnQHrvNeMTtzNerkLHJnuW5C/frD/CmNvV+5/1yOk6/fXLk813ywEyce8+0yPuFMb7N+S7KS4u2YlyD8/Aa4AmZKdkS3Y66OcJs2d29hT0LAaBPi+QQ3XWGsStil3wYbrLNbZ342UvLQs/d67whzFgbPjqcbXvoYuv8bafjdl+3Ea0qEXk5yyjjcO/a3xUxZZUxxOKNzV13w9jsnJVX1p/EzjCUx2YqrsdEpFlElg1EgryG3TUVB3uyAbhYYHh7XfDg/t4cdSDDytzrtBMtbF+LbW6yG8YuxJuGITpb9mUfUj1BLRwMN7+540/wOrcIws/0Sh5ZzArKoN6ngzGXmTE9hQxr0tIjk0pjU/TxOICLK5yOUG7td9w/19nrWzFlRbaJkzGXaTiGG9+ivhabjhl849rtG8ff/HQidjuz4viPvXpHfntdm1yzqZw4SNyKU+86m/LxIBt2dWDoiAkVm9BYVXHt4/MDx8a2SbO5LiBaWrp7M8by8KTryyi+dN+bmLRsR7WTkig2U3HNADAg02w07epAa5lfWW8Yuwj3v742dL2x6MNwk3R29+Ivi4PLxaNW6PnPM2d9K8bOC87lxtWX0VyZ/q59+dfYH0TCMuWz17WWnNOL2hrFWyZumvvSlJqpK7MP6hcXbTVsFV9XbwZTVzXju082FKyzGdUxcF3MeT5vfWEpzvvVG2jvjF6RmgT7DvZgzc79uHn84sD1zfsO1uRbRqIqEy9+YAZGz1hfseP721cD3psloB21W0YdcKyfvbgcPxyzEAs37Q79w4n793TV6Dm5Nwm/eRvaSh4yyT8dVthDw5/8x97aEHtC3rjzVuYFupjHCDxWibp7M9jhG2s7sKQoZi9S/37tB3rQ2Lyv4O3Hz63T6OwJb7c9mBzs6ctl3ja1dmLYnVPxhzfXVTlVA69sgVpEhotIg4g0tLTEmx5piAgyZQ56eccIWGbzmh8UgLa3Z8uqO7r6CtZVsiLoqw/Nzn2Oe0kCRnTNE5b8YpMKlNrEsVgFaOyZegw2tnbgn34zK9JkyQBw8/jFOOfuqejq7euvMDR19ze04Q8ivmOefvtkXDhyBi66326i37SMI3Llg7Nx9s9fB4Bc/dCbqwf/9GtRlS1Qq+poVa1X1fq6urp4iREpGkRKEViRZcjx2AbvsO1MrTJKKS8VU7QsviqvJYx/27x21AXHKFIx609fgKfmbcKW3fllqLaTPsQtdjGN2fL7N9Zh6db2wDLRnr4MVoa82by+shlANmftHjeo3sGulUx4EC9Wl1FwpASVClz39Dt4KGLu1/+3sHRre/+6mOX2aZCooo8hAkOOuvTfjqmDSFBQ1YKt+r3V2Bq6XzlzM3/13ndFOr5NIO3L
2Gap7QJ6//rsMrelTlC6Fm/eg6tGz8ktf8lTblys+aS3eV5vhCe6qQjL9DC9e+IqXPLATDTt6gAAtO7vwsGevrxjeQOpMSMQcHyrYqSQoQ1a9gXX5bhbF/s7yGQUN41bHPogKocJS7bj7ldXRdrHOHOQxTZpZdM8bwyA2QA+JiJbROTaiiVmiIQGZLflRimiVmTZbJPR8FfxBRt3o6Fpt03SQp103FGh68pV9OHPRYc9CPIq9AJO/qqTKx01NbzyFgB2e8Y8uWHsoshN/lSBX09ZE2l7wLIM3WPR5uzvrrUjGxTP/vnruMYZuS+ozN38oIn29mOqTHzsrSZ8+s7XjcO9FrukG9s6MW7BFnzvzwuKbAl87aHZ+Mxd2eKH3r4M/jRnY6QHJQCs2bkPP3p2kfV+QQ+xIUNqtxOQTauPq1X1eFU9TFVPVNVHK5YYkdBXvUdmbSj5+GG59TB2w3ZqaKuJ/3h6Ib5hMSTngZ6+0OBmyqXF5c9RhxbdqP/n/qKPoA4lH/3gewEAnzrxfcbzh11W4/X25eYb/BWiFs3fhgRcOJtRAL1/kw0bd/vWmYqKwnPUzfsOomFjeGMq02v+TKfC0G2GN3TEBFz31Duh27t+MWkVho6YACD79urd/r4pa/Avv38rt603oM7d0Iade7O/76fnbcLPXlyGP77VFH6iANc/vRDPv7MVa5vjDzLV/xZTWqTe3n4gL7Nga/m2duw3DK5VSYkr+vCXn5ZT9PGbw1+Zvft5dx3XsDlw/kWTto5ujAzJIbbs78LmtuB2seZ2zuErew0Fn4rw4OItJw4a1e20E44GAHzhY39VsM6bS/en27ZSzb0OfRkNLTYIWu7e2A/NWB8a0IPeIvzBLHCdIe156fFteNmoWXjozfWhxzC95gcFrAlLt+f2yNvI4w/T+8uL3e/rHuOBqWuxcFO23qRpVwf+9tZX8eLCwuaMe53xU/YciBbo+svx49/f/ks5a+0uPLdgC4Bs65Dv/WlB6L3i9dm7p2GY84Zgqy+juGzULFz7+PxI+5VLwgJ1eNFHOZjKk4PO6g7q020IvArNu5FvGr8Eu/ZHf1qH2dx2AH//yzfylrlBPWpLApe/maL/KKZWHya5YFfkLcD2Zh05eXUuLV29GfxwzEIA7kBb9ryna/a3ITfs5w9meesCctTBx0DgdmFlzP3Hd9IX+JCwaKlU5AqZjr96Z7YJYH/w9+6X/5axsbUDf5i+Ds82mNv9+9M8dMQE3DlhBQDg7okrMXTEhKL3vn9gr2senYsfj8u2t35jVTMmLd+BOyesNB7DFdjj18BNm/tGNdASF6hN9Vx3TbT7JYTZ3HYAK7b5Kk8Mf7DznfJlN+cTJJMZfJUb/kDtzYW81bgr90dcMCYIkBfFD3TnN00ckss1mc8fWvThWzNqWiP2Oa+aI55bkve2Va5Zd3I/Bx3PJkdd7OFlMT5JcPFW8YrGoIeEbT7HPUbQkAj9QdVQOeqs+odfTccvJq3CzUWGlx0ypDB9D8/MFmeOnrk+t86miadNZq5lX1fRh2Ex3rbytg/mSklYoDZfiHJ0hrnUM1Y1YFeT3NkdXi6VUbW+OcqtWFFO2EBV/ms8c+2u3Od//9OC3AOxoOhDNS+n9l8v5g//EhZAbntpWa61RNCBbYJuw8bded83qLzZxuIt+c0i3bT2BuSw+os3wgNp0VxgnER6BFeAhwf/3HPHOXH7gZ7AojjT9TM9cE1BfM76VgwdMQHbDME/6P521xV7EA+xeOi5Pn3n6/j0nfbFG929Gby0aGve97rl+aU45+6peRkSd/XIKWtwyi3Z8v6+jOJHzywq2hmpFIkK1NvaD2KcU+Y0UMLKEL1MOcSo7VzLyX/qFdv25oppMhre+mLvAXOFSKOhwsd7Ay3f1u5bF3wjPTF7I15c1N/dPvShWORaZmLmqL0337//Kb+VwwtOOeyDvva+Czbuxpz1bc7+hce0aee8dmf/jWuuTzCsM+bmw/d003f6
7ZNx4zOLfMc0t4c3BVVThd7/ezlblPGLSYVN8kz7HeI5n7Eoqgzl3GF+O20tbhi7CK8t729PP2VF9nNXb2GntlFT1+Z+943N+/H8wq24fsw7ZU+XK1GBeqD19GWscjzmysHqReo/z85vJnXpqJmYuDT7x/XM/E1o3hc8TGxQ2WMQU4AC8idoAPpvpD1FxpnwBzfbdud5gTqvclKN51zf0lH02P5JBv73H97OffZPVLu9/UCurNsUWr543wyMmra26Hadht6tpnqVoIdEUPCesCT/951Rc47a1IHHVD7uDnswYcn2vIdU9pjhuWH3fNv3HMTCTeFlwHaDhRWunN/UhqEjJgQOVvV24y783W2vYZ3zN+ItLrFtDug+OIOGqCiXmg7U3gs7s3FX6HZbdocPbZpRYMmW9tD1lbSvqxdn+qb3cv1N3XuMOWMT96oUjltt3s+9kYpN7Oo/rjdnvrE1OKh292by6i+8cebJ2Rtzucag8LN6Z/Arqbc4pv1AT0GZu2uTryXBX7xvB0XKVZ9/Z2tuuzDzmgqb6fVXzCLX4cY1ycn1BTWfdBeZAnFG8ytj/eO+25RRm98yFY845c/9x+xPo597vi/cOx3XPlE4sJX/3KZmkEGufDA77MKUlYV9MUZOWYN9Xb25v5GfvbQcv3ujMXtMZ5ti4Xcg2nfXdKD+85yNmOjkLlv2dRVWNHqETWLQ2d2HrXuqN6zkvoPBxRhnfvgYvLMpXjf1dmfci67eTF6TrlmNu0IrubbuOWBdqerf6oATMDOK0Fr7Ll95uzcdb6xuzn0Oei1eHvJ7vdg3bsYX73szNM355+7/bHtvRr2J3XNs2NWBL9w7PXAb/3dd27zfqnNPRvObN+49mP82YSz6sKhUy2jh+fuPWbj9kIC0+hedfccU3PaX5QDiv8MGjefS/2DrX/ar11bnb6Nq/L6m61UuNR2ofz5hJaau6r/J/RWNXsPumhq4/LFZGxI5AM5PnltafKMQ25ya7q7eTEF5Y1gAOPeeacbhZL38f8/um02fatGcWi4d3jR5Ps9cuytvCq5VO8Ifvk2+V2HTm5OXN7easXzdjTrok3uG7f4R+jz8Z84EdL4K8sDra9HreT25cGT+A8v9egs37Sk4nttRxFeiBdQAAAmuSURBVLYpHZB9ELhb+98OALuK4daObixwmsbFbcIb9MbkHisoDe6yt9a14iO3vhp4zExGreorSpXoqbgGgxXb99bUtEmmrxrWyqQY9w88G/RMObXgykR/Lt9b3nzx/eEP30pwe/75dXT3Gbt8+wO+TcVZUNGUu2TF9r347bTGwP1+P32dsTzVPXdndx8e9fUI/q1TLDB2/mbcfsVpocfw5lA/9T/9D86bxi/BlfUn5X7u7cugI6BVVdQKVpt1gWX6zv/++hag/zoETQjhOuWnE/HX788O8xBWJ1QODNRlsHEQzqjxFU9lWRS2M4bH0ZdRYwDxrpruGeoy6NW5krzNGaO87v7E0NbY35Ei12QtQjZNtb8o7OsP5w9dcOWD+b/vfYau0N7c5cMzg5vEdvVmCopMvGz/TO6dvCakUjR8n6AiNncgsMmGMYH8v6v1LfutiqT8naT83Ps/aDCycqnpoo9yiZuTrKa4PawqGRTHL9gS2DXdFfbK6w45OlC8c11GGS5gpyHH5b2ut7+83Or4/sBj6h043zc42NNzwyt8vYHaHeMDKByx0JtT9ntliV3LIn8TT5f5TaJw2U5PHZL3fvQOX+s/5vm/fjM05966vyvXw9iUox4oDNRlMNh6JpbidqetbDFxeoVNWr7DOHlwWMXpQFi2NTigTFtl/5AI6lTjGuKJ1H98qyn3xrDO0LTwpUX5U8E9/naTdVrCqGrow/hrnuFpiynWRNPlfTvxMl0rf8DNZPLvwFtf6K+f+aWnjiVojJuwIX/dyQqCNDZXrmNLGAbqEP4OECavrxjYHF01LbKc8CBKrzBb/pYfA+ny38wKXL4hoHIsjGkwLP88
ljZFB9MrMNPJprbO0OKtBWUa5+JVi3b8pjeJ3Z09eUVCK30Vxt5WQF59fYrvPpE/qNKyrdHH4/7aQ/YPrHJhoA5xj2HAc38HgrB2upR+YyJMQmx6yxju6zFZyboAk+zDsLJviN9/qngPPtNDra2j21c8lD+MQ9igaD2ZTFmKyVpjDJFaKgbqGK57unJdRWnwaQrppFMK23JRfyVhOc77viMPK+sx45qzvjV03ZOzN+Y+/+jZRaFl3W2eZpFu56NKshlmNQ6pxLCi9fX12tAQ3sMoTFjTJiKiwaLpnsti7SciC1S1PmidVY5aRC4WkdUi0igiI2KlgoiIYrGZM/EQAL8DcAmAUwFcLSKnVjphRESUZZOjHgagUVXXq2o3gLEArqhssoiIBifbYQWisAnUJwDwVm1vcZblEZHhItIgIg0tLfGaDT33/c/lPv/LmSfgtRvPw7Chx+G4dx+OH3/xo7l1y2+/CFcP+zAA4NrPn4wbL/wIHrzmbDxw1RlouucyzLjpH/HV+hMx6ca/x4a7L8WH3ncEAGDxf38J82+9EN/4zIfx3Pc/GyuNUX359A/hHz5aZ739779xVu7z/zr6CDz5nWF44jvD8PXPfBj3f+0Mq2Nc9MkP4lufGwoAeN+Rh+HU44/GJ44/GicccyQu/9Txoft92DPj+d9/5AN56+7+17/L+/m7nz/ZKi0md/xzcBfkQ8vUq+ar9Sfm/Rz2e/jXswr+nEtmmuD3Pe+K3iH46e9+Bl8+/UO5ny/4eOG8lDauP/9v837+0qkfjHUcCvbuww/JaxNfLkUrE0XkSgAXqep3nZ+/CWCYql4ftk/cykQiolpVamXiFgAneX4+EcC2kG2JiKjMbAL1fAAfEZGTReRwAFcB+Etlk0VERK6ihWWq2isi/wHgNQCHAHhMVZcX2Y2IiMrEqlZDVScCmFjhtBARUQB2ISciSjgGaiKihGOgJiJKOAZqIqKEq8joeSLSAmBj0Q2DfQBA8LQPtYvXpBCvSSFek0KD6Zr8taoGdp+tSKAuhYg0hPXOqVW8JoV4TQrxmhRKyzVh0QcRUcIxUBMRJVwSA/XoaicggXhNCvGaFOI1KZSKa5K4MmoiIsqXxBw1ERF5MFATESVcYgJ12ifQFZHHRKRZRJZ5lh0nIlNEZK3z/7HOchGRUc61WCIiZ3n2+Tdn+7Ui8m+e5WeLyFJnn1EiUv5pJspMRE4SkTdEZKWILBeRG5zlNXtdROQIEZknIouda3K7s/xkEZnrfL9nnCGHISLvcn5udNYP9RzrFmf5ahG5yLN8UN5rInKIiCwUkVecn2vnmqhq1f8hO3zqOgCnADgcwGIAp1Y7XWX+jucBOAvAMs+yXwIY4XweAeAXzudLAbwKQACcA2Cus/w4AOud/491Ph/rrJsH4LPOPq8CuKTa39nimhwP4Czn83sBrEF2AuWavS5OOt/jfD4MwFznuz4L4Cpn+YMAvu98/gGAB53PVwF4xvl8qnMfvQvAyc79dchgvtcA/AjA0wBecX6umWuSlBx16ifQVdUZANp8i68A8ITz+QkA/+xZ/qRmzQFwjIgcD+AiAFNUtU1VdwOYAuBiZ93Rqjpbs3+RT3qOlViqul1V33E+7wOwEtn5OGv2ujjfbb/z42HOPwVwPoDxznL/NXGv1XgAFzhvDVcAGKuqXaq6AUAjsvfZoLzXROREAJcBeMT5WVBD1yQpgdpqAt0U+qCqbgeyQQuAO2Np2PUwLd8SsHzQcF5Pz0Q2B1nT18V5xV8EoBnZh846AHtUtdfZxPs9ct/dWd8O4P2Ifq2S7n4ANwPIOD+/HzV0TZISqIPKDWu53WDY9Yi6fFAQkfcAeA7Ajaq617RpwLLUXRdV7VPVM5Cdn3QYgE8Ebeb8n/prIiKXA2hW1QXexQGbpvaaJCVQ1+oEujud13M4/zc7y8Ouh2n5iQHLE09EDkM2SD+lqs87i2v+ugCAqu4BMB3ZMupjRMSdkcn7PXLf3Vn/PmSL
2KJeqyQ7F8CXRaQJ2WKJ85HNYdfONal2IblTyH8oshVAJ6O/MP+T1U5XBb7nUORXJv4K+ZVmv3Q+X4b8SrN5zvLjAGxAtsLsWOfzcc66+c62bqXZpdX+vhbXQ5AtN77ft7xmrwuAOgDHOJ+PBDATwOUAxiG/4uwHzufrkF9x9qzz+ZPIrzhbj2yl2aC+1wB8Af2ViTVzTaqeAM8v4FJka/3XAbi12umpwPcbA2A7gB5kn+DXIltuNhXAWud/N7gIgN8512IpgHrPcb6DbCVII4Bve5bXA1jm7PNbOL1Ok/wPwOeRfcVcAmCR8+/SWr4uAD4FYKFzTZYB+G9n+SnItmBpdALUu5zlRzg/NzrrT/Ec61bne6+Gp7XLYL7XfIG6Zq4Ju5ATESVcUsqoiYgoBAM1EVHCMVATESUcAzURUcIxUBMRJRwDNRFRwjFQExEl3P8HwtMUFrRNqoAAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# One loss value is recorded per optimizer step of the training cell above\n",
    "plt.plot(losses)\n",
    "plt.xlabel(\"Training step\")\n",
    "plt.ylabel(\"MSE loss\")\n",
    "plt.title(\"DQN with target network: training loss\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Initial State:\n",
      "[[' ' ' ' ' ' ' ']\n",
      " [' ' ' ' ' ' 'W']\n",
      " ['+' ' ' ' ' ' ']\n",
      " ['-' 'P' ' ' ' ']]\n",
      "Move #: 0; Taking action: u\n",
      "[[' ' ' ' ' ' ' ']\n",
      " [' ' ' ' ' ' 'W']\n",
      " ['+' 'P' ' ' ' ']\n",
      " ['-' ' ' ' ' ' ']]\n",
      "Move #: 1; Taking action: l\n",
      "[[' ' ' ' ' ' ' ']\n",
      " [' ' ' ' ' ' 'W']\n",
      " ['+' ' ' ' ' ' ']\n",
      " ['-' ' ' ' ' ' ']]\n",
      "Game won! Reward: 10\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Watch the model trained with the target network play one displayed game in 'random' mode\n",
    "test_model(model, mode='random')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
